Skip to content

Commit

Permalink
Revert "test: added defering of deletion of testing dirs"
Browse files Browse the repository at this point in the history
This reverts commit c1788e5.
  • Loading branch information
Rexwang8 committed May 17, 2024
1 parent c1788e5 commit 57d55c4
Showing 1 changed file with 129 additions and 15 deletions.
144 changes: 129 additions & 15 deletions gpt_bpe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ var clipEncoder GPTEncoder
var gpt2Encoder GPTEncoder
var pileEncoder GPTEncoder
var nerdstashV2Encoder GPTEncoder
var Llama2Encoder GPTEncoder
var llama2Encoder GPTEncoder
var mistralEncoder GPTEncoder
var corpus string
var clipCorpus string

Expand All @@ -33,7 +34,8 @@ var gpt2Encoded *Tokens
var pileEncoded *Tokens
var clipEncoded *Tokens
var nerdstashEncoded *Tokens
var Llama2Encoded *Tokens
var llama2Encoded *Tokens
var mistralEncoded *Tokens
var unicodeTrimTests []*Tokens

const largeCorpusPath = "resources/wiki.train.raw"
Expand Down Expand Up @@ -93,7 +95,8 @@ func init() {
pileEncoder = NewPileEncoder()
clipEncoder = NewCLIPEncoder()
nerdstashV2Encoder = NewNerdstashV2Encoder()
Llama2Encoder = NewLlama2Encoder()
llama2Encoder = NewLlama2Encoder()
mistralEncoder = NewMistralEncoder()
textBytes := handleRead("resources/frankenstein.txt")
clipBytes := handleRead("resources/frankenstein_clip.txt")
corpus = string(textBytes)
Expand Down Expand Up @@ -787,25 +790,102 @@ func TestLlamaEncoder_Encode(t *testing.T) {

// TestLlamaTwoEncoder_Encode checks that the Llama 2 encoder tokenizes a
// two-sentence sample into the expected token IDs.
// NOTE(review): the two consecutive Encode assignments look like the removed
// (Llama2Encoder) and added (llama2Encoder) sides of a rename shown together
// in this diff view; both cannot coexist, since the second := would
// redeclare llamaTokens - confirm against the resulting file.
func TestLlamaTwoEncoder_Encode(t *testing.T) {
testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
llamaTokens := Llama2Encoder.Encode(&testString)
llamaTokens := llama2Encoder.Encode(&testString)
assert.Equal(t, llamaTokens, &Tokens{1576, 1701, 29916, 12500, 287, 975, 278, 447, 276, 29889, 13, 1576, 260, 4227, 280, 338, 8473, 1135, 278, 447, 276, 29889})
}

// TestLlamaTwoTokenizerDecode checks that decoding a fixed Llama 2 token
// sequence (starting with token 1) yields the expected text beginning with
// "<s>".
// NOTE(review): the two consecutive Decode assignments look like the removed
// (Llama2Encoder) and added (llama2Encoder) sides of a rename shown together
// in this diff view - confirm against the resulting file.
func TestLlamaTwoTokenizerDecode(t *testing.T) {
outputString := "<s>The fox jumped over the hare.\nThe turtle is faster than the hare."
llamaTokens := Tokens{1, 1576, 1701, 29916, 12500, 287, 975, 278, 447, 276, 29889, 13, 1576, 260, 4227, 280, 338, 8473, 1135, 278, 447, 276, 29889}
output := Llama2Encoder.Decode(&llamaTokens)
output := llama2Encoder.Decode(&llamaTokens)
assert.Equal(t, outputString, output)
}

// TestLlamaTwoEncodeDecode round-trips a sample string through the Llama 2
// encoder and expects the decoded text to equal the input.
// NOTE(review): the Encode/Decode pair appears twice, presumably as the
// removed (Llama2Encoder) and added (llama2Encoder) sides of a rename shown
// together in this diff view - confirm against the resulting file.
func TestLlamaTwoEncodeDecode(t *testing.T) {
testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
outputString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
llamaTokens := Llama2Encoder.Encode(&testString)
output := Llama2Encoder.Decode(llamaTokens)
llamaTokens := llama2Encoder.Encode(&testString)
output := llama2Encoder.Decode(llamaTokens)
assert.Equal(t, outputString, output)
}

func TestMistralEncoder_Encode(t *testing.T) {
testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
mistralTokens := mistralEncoder.Encode(&testString)
fmt.Printf("mistralTokens: %v\n", mistralTokens)
assert.Equal(t, mistralTokens, &Tokens{1, 415, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723})
}

// TestMistralTokenizerDecode verifies that decoding a fixed Mistral token
// sequence produces the expected text, which begins with "<s> ".
func TestMistralTokenizerDecode(t *testing.T) {
	encoded := Tokens{1, 415, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723}
	want := "<s> The fox jumped over the hare.\nThe turtle is faster than the hare."

	got := mistralEncoder.Decode(&encoded)

	assert.Equal(t, want, got)
}

// TestMistralEncodeDecode round-trips a sample string through the Mistral
// encoder and expects the decoded text to be the input prefixed with "<s> ".
func TestMistralEncodeDecode(t *testing.T) {
	input := "The fox jumped over the hare.\nThe turtle is faster than the hare."
	want := "<s> The fox jumped over the hare.\nThe turtle is faster than the hare."

	// Encode, then feed the resulting tokens straight back into Decode.
	roundTripped := mistralEncoder.Decode(mistralEncoder.Encode(&input))

	assert.Equal(t, want, roundTripped)
}

// TestMistralEncodeDecodeFrankenstein round-trips the full Frankenstein
// corpus through the Mistral encoder and expects the decoded text to be the
// original prefixed with "<s> ".
func TestMistralEncodeDecodeFrankenstein(t *testing.T) {
	frankensteinCorpus := "resources/frankenstein.txt"
	frankensteinText, err := os.ReadFile(frankensteinCorpus)
	if err != nil {
		// Abort here: carrying on would round-trip whatever was read
		// (typically nothing) and report a second, misleading mismatch.
		t.Fatalf("Error reading Frankenstein corpus: %v", err)
	}

	frankensteinString := string(frankensteinText)
	mistralTokens := mistralEncoder.Encode(&frankensteinString)
	output := mistralEncoder.Decode(mistralTokens)
	assert.Equal(t, "<s> "+frankensteinString, output)
}

// TestReadTokenizerConfig checks that NewEncoder resolves the eos, bos and
// pad tokens named in a model directory's tokenizer_config.json.
//
// It downloads EleutherAI/pythia-70m into a scratch directory, overwrites
// tokenizer_config.json with known token strings, builds an encoder from that
// directory and asserts the resulting token IDs. The scratch directory is
// removed when the test returns.
func TestReadTokenizerConfig(t *testing.T) {
	fmt.Println("Testing ReadTokenizerConfig")
	// JSON with eos, bos, pad as strings; these correspond to
	// 6669, 10989, 5428 in the pythia vocab.
	jsonStr := `{"eos_token": "TC", "bos_token": "TD", "pad_token": "TE"}`

	// Download a filler model so there is a vocabulary to resolve against.
	modelId := "EleutherAI/pythia-70m"
	destPath := "./TestReadTokenizerConfig"
	// One deferred removal covers every exit path: t.Fatalf stops the test
	// via runtime.Goexit, which still runs deferred calls.
	defer os.RemoveAll(destPath)
	if err := os.MkdirAll(destPath, 0755); err != nil {
		t.Fatalf("Error creating directory %s: %v", destPath, err)
	}

	var rsrcType resources.ResourceType = resources.RESOURCETYPE_TRANSFORMERS
	hfApiToken := os.Getenv("HF_API_TOKEN")
	_, rsrcErr := resources.ResolveResources(modelId, &destPath,
		resources.RESOURCE_MODEL, rsrcType, hfApiToken)
	if rsrcErr != nil {
		// Nothing below can work without the model files.
		t.Fatalf("Error downloading model resources: %s", rsrcErr)
	}

	// Replace tokenizer_config.json with jsonStr.
	tokenizerConfigPath := destPath + "/tokenizer_config.json"
	if err := os.WriteFile(tokenizerConfigPath, []byte(jsonStr), 0644); err != nil {
		t.Fatalf("Error writing to tokenizer_config.json: %v", err)
	}

	// Build an encoder from the directory so it reads the rewritten config.
	encoder, err := NewEncoder(destPath)
	if err != nil {
		// Stop before touching encoder's fields; it may not be usable here.
		t.Fatalf("Error creating encoder: %v", err)
	}

	// Check that the tokens are correct; testify takes (expected, actual).
	assert.Equal(t, Token(6669), encoder.EosToken)
	assert.Equal(t, Token(10989), encoder.BosToken)
	assert.Equal(t, Token(5428), encoder.PadToken)

	fmt.Println("All Exists - Looks good.")
}

// TestGPTDecoder_Decode is a placeholder: it has no body yet, so it always
// passes without exercising any decoding.
func TestGPTDecoder_Decode(t *testing.T) {
// TBD
}
Expand All @@ -823,10 +903,10 @@ func TestModelDownload(t *testing.T) {
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
if rsrcErr != nil {
os.RemoveAll(destPath)
t.Errorf("Error downloading model resources: %s", rsrcErr)
}

Expand All @@ -841,9 +921,11 @@ func TestModelDownload(t *testing.T) {
fmt.Println("config.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("config.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for config.json")
}

Expand All @@ -853,9 +935,11 @@ func TestModelDownload(t *testing.T) {
fmt.Println("pytorch_model.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model.bin")
}

Expand All @@ -865,9 +949,11 @@ func TestModelDownload(t *testing.T) {
fmt.Println("tokenizer.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("tokenizer.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for tokenizer.json")
}

Expand All @@ -877,13 +963,16 @@ func TestModelDownload(t *testing.T) {
fmt.Println("vocab.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("vocab.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for vocab.json")
}

// Finish Test - Deferred removal of the downloaded folder
// Clean up by removing the downloaded folder
os.RemoveAll(destPath)
fmt.Println("All Exists - Looks good.")
}

Expand All @@ -900,10 +989,10 @@ func TestModelDownloadPythia(t *testing.T) {
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
if rsrcErr != nil {
os.RemoveAll(destPath)
t.Errorf("Error downloading model resources: %s", rsrcErr)
}

Expand All @@ -918,9 +1007,11 @@ func TestModelDownloadPythia(t *testing.T) {
fmt.Println("config.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("config.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for config.json")
}

Expand All @@ -930,9 +1021,11 @@ func TestModelDownloadPythia(t *testing.T) {
fmt.Println("pytorch_model.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model.bin")
}

Expand All @@ -942,9 +1035,11 @@ func TestModelDownloadPythia(t *testing.T) {
fmt.Println("tokenizer.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("tokenizer.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for tokenizer.json")
}

Expand All @@ -954,13 +1049,16 @@ func TestModelDownloadPythia(t *testing.T) {
fmt.Println("vocab.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("vocab.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for vocab.json")
}

// Finish Test - Deferred removal of the downloaded folder
// Clean up by removing the downloaded folder
os.RemoveAll(destPath)
fmt.Println("All Exists - Looks good.")
}

Expand All @@ -976,10 +1074,10 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
if rsrcErr != nil {
os.RemoveAll(destPath)
t.Errorf("Error downloading model resources: %s", rsrcErr)
}

Expand All @@ -994,9 +1092,11 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
fmt.Println("pytorch_model-00001-of-00002.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model-00001-of-00002.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model-00001-of-00002.bin")
}

Expand All @@ -1006,9 +1106,11 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
fmt.Println("pytorch_model-00002-of-00002.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model-00002-of-00002.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model-00002-of-00002.bin")
}

Expand All @@ -1018,13 +1120,16 @@ func TestModelDownloadPythiaSharded(t *testing.T) {
fmt.Println("pytorch_model.bin.index.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model.bin.index.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model.bin.index.json")
}

// Finish Test - Deferred removal of the downloaded folder
// Clean up by removing the downloaded folder
os.RemoveAll(destPath)
fmt.Println("All Exists - Looks good.")

}
Expand Down Expand Up @@ -1137,10 +1242,10 @@ func TestModelDownloadFairseq(t *testing.T) {
rsrcType = resources.RESOURCETYPE_TRANSFORMERS
hfApiToken := os.Getenv("HF_API_TOKEN")
os.MkdirAll(destPath, 0755)
defer os.RemoveAll(destPath)
_, rsrcErr := resources.ResolveResources(modelId, destPathPTR,
resources.RESOURCE_MODEL, rsrcType, hfApiToken)
if rsrcErr != nil {
os.RemoveAll(destPath)
t.Errorf("Error downloading model resources: %s", rsrcErr)
}

Expand All @@ -1154,9 +1259,11 @@ func TestModelDownloadFairseq(t *testing.T) {
fmt.Println("config.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("config.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for config.json")
}

Expand All @@ -1166,9 +1273,11 @@ func TestModelDownloadFairseq(t *testing.T) {
fmt.Println("pytorch_model.bin exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("pytorch_model.bin does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for pytorch_model.bin")
}

Expand All @@ -1178,9 +1287,11 @@ func TestModelDownloadFairseq(t *testing.T) {
fmt.Println("vocab.json exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("vocab.json does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for vocab.json")
}

Expand All @@ -1190,12 +1301,15 @@ func TestModelDownloadFairseq(t *testing.T) {
fmt.Println("merges.txt exists")

} else if errors.Is(err, os.ErrNotExist) {
os.RemoveAll(destPath)
t.Errorf("merges.txt does not exist")

} else {
os.RemoveAll(destPath)
t.Errorf("Error checking for merges.txt")
}

// Finish Test - Deferred removal of the downloaded folder
// Clean up by removing the downloaded folder
os.RemoveAll(destPath)
fmt.Println("All Exists - Looks good (Fairseq Download).")
}

0 comments on commit 57d55c4

Please sign in to comment.