Skip to content

Commit

Permalink
style: clean up some print code
Browse files Browse the repository at this point in the history
  • Loading branch information
Rexwang8 committed May 19, 2024
1 parent 741198f commit 96eb515
Showing 1 changed file with 39 additions and 50 deletions.
89 changes: 39 additions & 50 deletions gpt_bpe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,10 @@ func TestGPTEncoder_Split(t *testing.T) {
func BenchmarkGPTEncoder_WordSplitterChan(b *testing.B) {
b.StopTimer()
corpusHandle, err := os.Open(largeCorpusPath)
defer corpusHandle.Close()
if err != nil {
b.Error(err)
}
defer corpusHandle.Close()
gpt2Encoder.SplitterThreads = 8
nextWord := gpt2Encoder.WordSplitter(bufio.NewReaderSize(corpusHandle,
8*1024*1024))
Expand Down Expand Up @@ -413,8 +413,8 @@ func BenchmarkGPTEncoder_Decode(b *testing.B) {
start := time.Now()
tokenNumBytes := len(gpt2Encoder.Decode(gpt2Encoded))
duration := time.Since(start)
b.Log(fmt.Sprintf("%v tokens into %v bytes over %v",
len(*gpt2Encoded), tokenNumBytes, duration))
b.Logf("%v tokens into %v bytes over %v",
len(*gpt2Encoded), tokenNumBytes, duration)
}

type EncoderTest struct {
Expand Down Expand Up @@ -453,25 +453,25 @@ func BenchmarkGPTEncoder_Encode(b *testing.B) {
start := time.Now()
tokenCt := len(*gpt2Encoder.Encode(&corpus))
duration := time.Since(start)
b.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration))
b.Logf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration)
}

func BenchmarkGPTEncoder_EncodeBuffer(b *testing.B) {
corpusBytes := []byte(corpus)
start := time.Now()
tokenCt := len(*gpt2Encoder.EncodeBuffer(&corpusBytes)) / 2
duration := time.Since(start)
b.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration))
b.Logf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration)
}

func TestGPTEncoder_Encode(t *testing.T) {
start := time.Now()
tokenCt := len(*gpt2Encoder.Encode(&corpus))
duration := time.Since(start)
t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration))
t.Logf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration)
for testIdx := range GPTEncoderTests {
tokensPtr := *gpt2Encoder.Encode(
&(GPTEncoderTests[testIdx].Input))
Expand All @@ -492,19 +492,18 @@ func TestGPTEncoder_StreamingEncode(t *testing.T) {
tokenCt += len(*tokens)
}
duration := time.Since(start)
t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration))
t.Logf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration)
}

func TestCLIPEncoder_Encode(t *testing.T) {
start := time.Now()
tokenCt := len(*clipEncoder.Encode(&corpus))
duration := time.Since(start)
t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration))
t.Logf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration)
for testIdx := range GPTEncoderTests {
testStr := fmt.Sprintf("%s",
GPTEncoderTests[testIdx].Input)
testStr := GPTEncoderTests[testIdx].Input
tokensPtr := *clipEncoder.Encode(&testStr)
assert.Equal(t, GPTEncoderTests[testIdx].CLIPExpected, tokensPtr)
}
Expand All @@ -514,8 +513,8 @@ func TestPileEncoder_Encode(t *testing.T) {
start := time.Now()
tokenCt := len(*pileEncoder.Encode(&corpus))
duration := time.Since(start)
t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration))
t.Logf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration)
for testIdx := range GPTEncoderTests {
tokensPtr := *pileEncoder.Encode(
&(GPTEncoderTests[testIdx].Input))
Expand All @@ -527,8 +526,8 @@ func TestNerdstashEncoder_Encode(t *testing.T) {
start := time.Now()
tokenCt := len(*nerdstashV2Encoder.Encode(&corpus))
duration := time.Since(start)
t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration))
t.Logf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration)
for testIdx := range GPTEncoderTests {
tokensPtr := *nerdstashV2Encoder.Encode(
&(GPTEncoderTests[testIdx].Input))
Expand Down Expand Up @@ -576,7 +575,7 @@ func TestNerdstashEncoder_Encode2(t *testing.T) {
encoded := nerdstashV2Encoder.Encode(&inputStr)
// check that the encoded string is the same as the expected
if !assert.Equal(t, expected, *encoded) {
t.Log(fmt.Sprintf("failure on input: `%v`", inputStr))
t.Logf("failure on input: `%v`", inputStr)
expectedRepr := []string{}
for _, token := range expected {
expectedRepr = append(expectedRepr,
Expand All @@ -587,14 +586,14 @@ func TestNerdstashEncoder_Encode2(t *testing.T) {
actualRepr = append(actualRepr,
string(nerdstashV2Encoder.Decoder[token]))
}
t.Log(fmt.Sprintf("expected: |%s", strings.Join(expectedRepr, "|")))
t.Log(fmt.Sprintf("actual: |%s", strings.Join(actualRepr, "|")))
t.Logf("expected: |%s", strings.Join(expectedRepr, "|"))
t.Logf("actual: |%s", strings.Join(actualRepr, "|"))
failCt += 1
} else {
passCt += 1
}
}
t.Log(fmt.Sprintf("pass: %v, fail: %v", passCt, failCt))
t.Logf("pass: %v, fail: %v", passCt, failCt)
}

func TestNerdstashEncoder_Decode(t *testing.T) {
Expand All @@ -613,6 +612,9 @@ func TestGPTEncoder_Decode2(t *testing.T) {
} else {
tokens := TokensFromBin(&binTokens)
tokens, err = gpt2Encoder.TrimIncompleteSentence(tokens)
if err != nil {
t.Error(err)
}
assert.Equal(t, gpt2Encoder.Decode(tokens), decodedCorpus)
}
}
Expand All @@ -626,8 +628,8 @@ func TestGPTEncoder_Decode(t *testing.T) {
decoded := gpt2Encoder.Decode(gpt2Encoded)
duration := time.Since(start)
tokenNumBytes := len(decoded)
t.Log(fmt.Sprintf("%v tokens into %v bytes over %v\n",
len(*gpt2Encoded), tokenNumBytes, duration))
t.Logf("%v tokens into %v bytes over %v\n",
len(*gpt2Encoded), tokenNumBytes, duration)
assert.Equal(t, corpus, decoded)
}

Expand All @@ -647,8 +649,7 @@ func TestCLIPEncoder_Decode(t *testing.T) {
duration := time.Since(start)
tokenNumBytes := len(decoded)
idxToStop := 229550
t.Log(fmt.Sprintf("%v tokens into %v bytes over %v\n",
len(*clipEncoded), tokenNumBytes, duration))
t.Logf("%v tokens into %v bytes over %v\n", len(*clipEncoded), tokenNumBytes, duration)
for idx := range clipCorpus {
if idx > idxToStop {
break
Expand All @@ -672,8 +673,8 @@ func TestPileEncoder_Decode(t *testing.T) {
decoded := pileEncoder.Decode(pileEncoded)
duration := time.Since(start)
tokenNumBytes := len(decoded)
t.Log(fmt.Sprintf("%v tokens into %v bytes over %v\n",
len(*pileEncoded), tokenNumBytes, duration))
t.Logf("%v tokens into %v bytes over %v\n",
len(*pileEncoded), tokenNumBytes, duration)
range_data := corpus
if len(corpus) > len(decoded) {
range_data = decoded
Expand Down Expand Up @@ -779,8 +780,8 @@ func TestLlamaEncoder_Encode(t *testing.T) {
start := time.Now()
tokenCt := len(*gpt2Encoder.Encode(&corpus))
duration := time.Since(start)
t.Log(fmt.Sprintf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration))
t.Logf("%v bytes into %v tokens over %v",
len(corpus), tokenCt, duration)
for testIdx := range GPTEncoderTests {
tokensPtr := *gpt2Encoder.Encode(
&(GPTEncoderTests[testIdx].Input))
Expand Down Expand Up @@ -852,9 +853,7 @@ func TestReadTokenizerConfig(t *testing.T) {
destPath := "./TestReadTokenizerConfig"
destPathPTR := &destPath
defer os.RemoveAll(destPath)
var rsrcType resources.ResourceType
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
Expand Down Expand Up @@ -898,9 +897,7 @@ func TestModelDownload(t *testing.T) {
destPath := "./TestModelDownload"
destPathPTR := &destPath

var rsrcType resources.ResourceType
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
Expand Down Expand Up @@ -975,9 +972,7 @@ func TestModelDownloadPythia(t *testing.T) {
destPath := "./TestModelDownloadPythia"
destPathPTR := &destPath

var rsrcType resources.ResourceType
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
Expand Down Expand Up @@ -1051,9 +1046,7 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
destPath := "./TestModelDownloadPythiaSharded"
destPathPTR := &destPath

var rsrcType resources.ResourceType
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
Expand Down Expand Up @@ -1118,9 +1111,7 @@ func TestModelDownloadLlama(t *testing.T) {
destPathPTR := &destPath
defer os.RemoveAll(destPath)

var rsrcType resources.ResourceType
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
Expand Down Expand Up @@ -1166,7 +1157,7 @@ func TestModelDownloadLlama(t *testing.T) {
}

matches := re.FindStringSubmatch(file.Name())
if matches != nil && len(matches) > 2 {
if len(matches) > 2 {
if strings.Compare(matches[1], matches[2]) == 0 {
found = true
break
Expand Down Expand Up @@ -1212,9 +1203,7 @@ func TestModelDownloadFairseq(t *testing.T) {
destPath := "./TestModelDownloadFairseq"
destPathPTR := &destPath

var rsrcType resources.ResourceType
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
rsrcType, hfApiToken := resources.RESOURCETYPE_TRANSFORMERS, os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
Expand Down

0 comments on commit 96eb515

Please sign in to comment.