Update torchao to 0.5.0 and fix GPU quantization tutorial (#3069)
* Update torchao to 0.5.0 and fix GPU quantization tutorial
---------

Co-authored-by: HDCharles <[email protected]>
svekars and HDCharles authored Oct 1, 2024
1 parent 8d959ca commit 54b62c1
Showing 2 changed files with 23 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/requirements.txt
@@ -68,5 +68,5 @@ iopath
 pygame==2.6.0
 pycocotools
 semilearn==0.3.2
-torchao==0.0.3
+torchao==0.5.0
 segment_anything==1.0
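To pick up the new pin outside of CI, one way (assuming a pip-based setup) is simply:

pip install torchao==0.5.0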
31 changes: 22 additions & 9 deletions prototype_source/gpu_quantization_torchao_tutorial.py
@@ -44,7 +44,8 @@
 #
 
 import torch
-from torchao.quantization import change_linear_weights_to_int8_dqtensors
+from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
 from segment_anything import sam_model_registry
 from torch.utils.benchmark import Timer
 
@@ -156,9 +157,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # in memory bound situations where the benefit comes from loading less
 # weight data, rather than doing less computation. The torchao APIs:
 #
-# ``change_linear_weights_to_int8_dqtensors``,
-# ``change_linear_weights_to_int8_woqtensors`` or
-# ``change_linear_weights_to_int4_woqtensors``
+# ``int8_dynamic_activation_int8_weight()``,
+# ``int8_weight_only()`` or
+# ``int4_weight_only()``
 #
 # can be used to easily apply the desired quantization technique and then
 # once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
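As a minimal sketch of how any of these three configs is applied through the new ``quantize_`` API (torchao 0.5.0 assumed; the single-linear model is illustrative only, not part of the tutorial):

import torch
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    int4_weight_only,
)

# toy single-linear model, for illustration only (a CUDA GPU and bfloat16 are assumed)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# quantize_ rewrites the linear weights in place; pick exactly one config
quantize_(model, int8_dynamic_activation_int8_weight())
# quantize_(model, int8_weight_only())
# quantize_(model, int4_weight_only())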
@@ -170,7 +171,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
 # ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
 # above (no replacement for int4).
 #
-# The difference between the two APIs is that ``change_linear_weights`` API
+# The difference between the two APIs is that ``int8_dynamic_activation`` API
 # alters the weight tensor of the linear module so instead of doing a
 # normal linear, it does a quantized operation. This is helpful when you
 # have non-standard linear ops that do more than one thing. The ``apply``
@@ -185,7 +186,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(only_one_block, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -220,7 +224,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -251,7 +258,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.coordinate_descent_check_all_directions = True
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
@@ -280,7 +290,10 @@ def get_sam_model(only_one_block=False, batchsize=1):
 model, image = get_sam_model(False, batchsize)
 model = model.to(torch.bfloat16)
 image = image.to(torch.bfloat16)
-change_linear_weights_to_int8_dqtensors(model)
+quantize_(model, int8_dynamic_activation_int8_weight())
+if not TORCH_VERSION_AT_LEAST_2_5:
+    # needed for subclass + compile to work on older versions of pytorch
+    unwrap_tensor_subclass(model)
 model_c = torch.compile(model, mode='max-autotune')
 quant_res = benchmark(model_c, image)
 print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
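Every call site above migrates the same way; assembled into one sketch, with ``get_sam_model`` and ``benchmark`` being the tutorial's own helpers assumed in scope:

import torch
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5

model, image = get_sam_model(only_one_block=True, batchsize=1)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)

# in-place swap of linear weights for int8 dynamic-quantization tensor subclasses
quantize_(model, int8_dynamic_activation_int8_weight())
if not TORCH_VERSION_AT_LEAST_2_5:
    # needed for subclass + compile to work on older versions of pytorch
    unwrap_tensor_subclass(model)

model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"quantized block: {quant_res['time']:0.2f}ms, peak memory {quant_res['memory']:0.2f}GB")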