Skip to content

Commit

Permalink
Merge pull request #12 from coreweave/rwang.s3tokenizer.10082023
Browse files Browse the repository at this point in the history
Update dataset_tokenizer.go
  • Loading branch information
wbrown authored Oct 9, 2023
2 parents d44973d + 56201e9 commit 1bca3fc
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions cmd/dataset_tokenizer/dataset_tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,18 +316,37 @@ func fetchTextFileS3(svc S3Client, bucketName, objectKey string) (string, error)
}

// removeS3Prefix parses an S3 URI of the form "s3://bucket[/path]".
//
// It returns:
//   - hasS3Prefix: true only when input starts with "s3://" and names a
//     non-empty bucket.
//   - remainder: the bucket name when hasS3Prefix is true; otherwise the
//     unmodified input.
//   - s3FilePath: everything after the first "/" that follows the bucket
//     name, or "" when there is none (e.g. "s3://bucket" or "s3://bucket/").
//
// Inputs without a usable bucket ("s3://", "s3:///", "s3:///path") are
// reported as not having an S3 prefix and are returned unchanged.
func removeS3Prefix(input string) (hasS3Prefix bool, remainder string, s3FilePath string) {
	const prefix = "s3://"
	if !strings.HasPrefix(input, prefix) {
		return false, input, ""
	}

	// Split "bucket/path" once at the first slash instead of re-scanning
	// the string for every case.
	rest := input[len(prefix):]
	bucket, path := rest, ""
	if i := strings.Index(rest, "/"); i >= 0 {
		bucket, path = rest[:i], rest[i+1:]
	}

	// An empty bucket is never valid, regardless of what follows it.
	if bucket == "" {
		return false, input, ""
	}
	return true, bucket, path
}

// ReadTextsFromS3 reads text files recursively from all prefixes in an S3 bucket.
// When s3FilePath is non-empty, only objects whose key starts with s3FilePath are processed.
func ReadTextsFromS3(
svc S3Client,
bucketName string,
s3FilePath string,
sanitize bool,
numReaderThreads int,
) (chan namedRuneReader, error) {
Expand All @@ -342,6 +361,7 @@ func ReadTextsFromS3(
if !ok {
break
}
if s3FilePath == "" || strings.HasPrefix(*object.Key, s3FilePath) {

if strings.HasSuffix(*object.Key, ".jsonl") {
// Handle JSONL files.
Expand Down Expand Up @@ -385,6 +405,7 @@ func ReadTextsFromS3(
}
}
}
}
}
wg.Done()
}
Expand Down Expand Up @@ -1237,7 +1258,7 @@ func main() {
log.Fatal(tokErr)
}

hasS3Prefix, s3Bucket := removeS3Prefix(*inputDir)
hasS3Prefix, s3Bucket, s3FilePath := removeS3Prefix(*inputDir)

if hasS3Prefix && *s3Endpoint == "" {
flag.Usage()
Expand Down Expand Up @@ -1267,7 +1288,7 @@ func main() {
}))

svc := s3.New(sess)
textReaders, err = ReadTextsFromS3(svc, s3Bucket, *sanitizeBool, *numReaderThreads)
textReaders, err = ReadTextsFromS3(svc, s3Bucket, s3FilePath, *sanitizeBool, *numReaderThreads)

if err != nil {
log.Fatal(err)
Expand Down

0 comments on commit 1bca3fc

Please sign in to comment.