[aoti] Remove need for -l in cmake call
angelayi committed Sep 17, 2024
1 parent f285434 commit 23f96d4
Showing 3 changed files with 37 additions and 37 deletions.

README.md: 2 changes (1 addition & 1 deletion)
@@ -290,7 +290,7 @@ torchchat/utils/scripts/build_native.sh aoti

Then run the compiled executable, with the pt2.
```bash
cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -i "Once upon a time"
```

## Mobile Execution

runner/run.cpp: 59 changes (25 additions & 34 deletions)
@@ -32,8 +32,6 @@ LICENSE file in the root directory of this source tree.

#ifdef __AOTI_MODEL__
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
torch::Device aoti_device(torch::kCPU);

#else // __ET_MODEL__
#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/managed_tensor.h>
@@ -88,9 +86,11 @@ typedef struct {
typedef struct {
Config config; // the hyperparameters of the architecture (the blueprint)
RunState state; // buffers for the "wave" of activations in the forward pass
std::unordered_map<std::string, std::string> metadata;

#ifdef __AOTI_MODEL__
torch::inductor::AOTIModelPackageLoader* runner;

#else // __ET_MODEL__
Module* runner;
#endif
@@ -129,19 +129,9 @@ void read_checkpoint(char* checkpoint, Config* config) {

void build_transformer(
Transformer* t,
char* model_path,
int vocab_size,
int seq_len) {
// read in the Config and the Weights from the model
// read_checkpoint(model_path, &t->config);
// allocate the RunState buffers
t->config.vocab_size = vocab_size;
t->config.seq_len = seq_len;
malloc_run_state(&t->state, &t->config);

char* model_path) {
#ifdef __AOTI_MODEL__
t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
#else //__ET_MODEL__
t->runner = new Module(
/* path to PTE model */ model_path,
@@ -193,6 +183,9 @@ float* forward(Transformer* transformer, int token, int pos) {
torch::Tensor token_tensor =
torch::from_blob(token_buffer, {1, 1}, torch::kLong);
torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
torch::Device aoti_device = transformer->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
? torch::Device(torch::kCPU)
: torch::Device(torch::kCUDA);
std::vector<torch::Tensor> inputs{
token_tensor.to(aoti_device), pos_tensor.to(aoti_device)};

@@ -880,26 +873,25 @@ int main(int argc, char* argv[]) {
system_prompt = argv[i + 1];
} else if (argv[i][1] == 'l') {
llama_ver = atoi(argv[i + 1]);
#ifdef __AOTI_MODEL__
} else if (argv[i][1] == 'd') {
#ifdef USE_CUDA
if (strcasecmp(argv[i + 1], "CUDA") == 0) {
aoti_device = torch::Device(torch::kCUDA);
} else
#endif
if (strcasecmp(argv[i + 1], "CPU") == 0) {
aoti_device = torch::Device(torch::kCPU);
} else {
fprintf(stderr, "Unknown device %s", argv[i + 1]);
exit(1);
}
#endif
} else {
error_usage();
}
}

if (model_path == NULL) {
fprintf(stderr, "No model_path provided.");
error_usage();
}

Transformer transformer;
build_transformer(&transformer, model_path);

#ifdef __AOTI_MODEL__
ModelType model_type = get_model_type(std::stoi(transformer.runner->get_metadata()["tokenizer_type"]));
#else // __ET_MODEL__
ModelType model_type = get_model_type(llama_ver);
#endif

if (model_type == UNKNOWN_MODEL) {
fprintf(
stderr,
@@ -908,11 +900,6 @@ int main(int argc, char* argv[]) {
error_usage();
}

if (model_path == NULL) {
fprintf(stderr, "No model_path provided.");
error_usage();
}

if (tokenizer_path == NULL) {
fprintf(stderr, "No tokenizer_path provided.");
error_usage();
@@ -935,8 +922,12 @@ int main(int argc, char* argv[]) {
vocab_size = tokenizer->vocab_size();
}

Transformer transformer;
build_transformer(&transformer, model_path, vocab_size, steps);
// read in the Config and the Weights from the model
// read_checkpoint(model_path, &t->config);
// allocate the RunState buffers
transformer.config.vocab_size = vocab_size;
transformer.config.seq_len = steps;
malloc_run_state(&transformer.state, &transformer.config);

Sampler sampler;
build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);
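
Taken together, the run.cpp changes swap command-line hints for metadata carried inside the .pt2 package: the device now comes from the "AOTI_DEVICE_KEY" entry and the model family from "tokenizer_type". A minimal sketch of that lookup, assuming the package was exported with the metadata written by torchchat/export.py below; the function name inspect_package and the printf are illustrative only, not part of this commit:

```cpp
#include <cstdio>
#include <string>
#include <unordered_map>

#include <torch/torch.h>
#include <torch/csrc/inductor/aoti_package/model_package_loader.h>

// Sketch: resolve device and model family from the packaged metadata instead of flags.
void inspect_package(const std::string& model_path) {
  torch::inductor::AOTIModelPackageLoader loader(model_path);
  std::unordered_map<std::string, std::string> metadata = loader.get_metadata();

  // "AOTI_DEVICE_KEY" is recorded at export time; "cpu" selects kCPU, anything else CUDA here.
  torch::Device device = metadata["AOTI_DEVICE_KEY"] == "cpu"
      ? torch::Device(torch::kCPU)
      : torch::Device(torch::kCUDA);

  // "tokenizer_type" is the string written by export.py ("2" = SentencePiece, otherwise "3");
  // a missing key is treated as 0 (unknown), matching the default in export.py.
  int tokenizer_type = metadata.count("tokenizer_type")
      ? std::stoi(metadata.at("tokenizer_type"))
      : 0;

  std::printf("device=%s tokenizer_type=%d\n", device.str().c_str(), tokenizer_type);
}
```

This is why main() no longer needs -l for AOTI models: get_model_type() is fed std::stoi(metadata["tokenizer_type"]) from the package rather than the llama_ver argument.
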
torchchat/export.py: 13 changes (11 additions & 2 deletions)
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional
from typing import Dict, Optional

import torch
import torch.nn as nn
@@ -39,6 +39,7 @@ def export_for_server(
output_path: str = "model.pt2",
dynamic_shapes: bool = False,
package: bool = True,
metadata: Optional[Dict[str, str]] = None,
) -> str:
"""
Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,7 +68,7 @@ def export_for_server(
dynamic_shapes = None

with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
metadata = {} # TODO: put more metadata here
metadata = metadata or {}
options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
if not package:
options = {"aot_inductor.output_path": output_path}
@@ -373,6 +374,7 @@ def main(args):

# TODO: clean this up
# This mess is because ET does not support _weight_int4pack_mm right now
tokenizer_args = None
if not builder_args.gguf_path:
# tokenizer needed for quantization so get that here,
try:
@@ -443,11 +445,18 @@

if output_aoti_package_path:
output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))

tokenizer_type = "0"
if tokenizer_args is not None:
tokenizer_type = "2" if tokenizer_args.is_sentencepiece else "3"

metadata = {"tokenizer_type": tokenizer_type}
print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
export_for_server(
model_to_aoti_package,
builder_args.device,
output_aoti_package_path,
builder_args.dynamic_shapes,
package=True,
metadata=metadata,
)
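
The export-side half of the change, condensed into a runnable sketch. The helper names (build_aoti_metadata, aoti_options) and the fake tokenizer args are hypothetical, not repo code; the dictionary keys and tokenizer_type values mirror the hunks above, and the actual AOT Inductor compile call that consumes `options` sits outside the lines shown:

```python
from typing import Dict, Optional


def build_aoti_metadata(tokenizer_args) -> Dict[str, str]:
    """Hypothetical helper mirroring the logic added to main() above."""
    # "0" = no tokenizer info, "2" = SentencePiece, "3" = anything else.
    if tokenizer_args is None:
        return {"tokenizer_type": "0"}
    return {"tokenizer_type": "2" if tokenizer_args.is_sentencepiece else "3"}


def aoti_options(output_path: str, package: bool, metadata: Optional[Dict[str, str]]) -> dict:
    """Hypothetical helper mirroring export_for_server(): where the metadata lands."""
    metadata = metadata or {}
    options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
    if not package:
        # Without packaging there is no .pt2 to carry metadata; only an output path is set.
        options = {"aot_inductor.output_path": output_path}
    return options


class _FakeSentencePieceArgs:
    is_sentencepiece = True


# The metadata lands under "aot_inductor.metadata", which is what the C++ runner
# reads back via AOTIModelPackageLoader::get_metadata() in runner/run.cpp above.
print(aoti_options("model.pt2", package=True, metadata=build_aoti_metadata(_FakeSentencePieceArgs())))
```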
