"Where is the overloaded function for torch.nn.functional.linear(aqt, original_weight_tensor, bias)? " #1397

Open
Lenan22 opened this issue Dec 10, 2024 · 2 comments

Comments

@Lenan22

Lenan22 commented Dec 10, 2024

Here is an example using int8_dynamic_activation_int8_weight:

aqt:

AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[ 5, -2, 24, ..., 17, 73, 54],
[ -30, -19, -53, ..., -9, -33, 55],
[ -7, -20, -28, ..., 47, 71, -15],
...,
[ 36, 8, 40, ..., 13, -10, 45],
[ -38, -12, 47, ..., -22, 0, -29],
[ 20, -127, 52, ..., 18, 27, -36]], dtype=torch.int8)... , scale=tensor([0.0293, 0.0233, 0.0271, 0.0234, 0.0209, 0.0227, 0.0247, 0.0328, 0.0270,
0.0215, 0.0245, 0.0209, 0.0325, 0.0232, 0.0238, 0.0267, 0.0237, 0.0202,
0.0249, 0.0239, 0.0255, 0.0246, 0.0225, 0.0288, 0.0194, 0.0215, 0.0224,
0.0210, 0.0253, 0.0189, 0.0240, 0.0228, 0.0208, 0.0211, 0.0295, 0.0275,
0.0200, 0.0250, 0.0202, 0.0269, 0.0266, 0.0203, 0.0223, 0.0246, 0.0212,
0.0217, 0.0246, 0.0203, 0.0219, 0.0237, 0.0216, 0.0191, 0.0213, 0.0227,
0.0330, 0.0194, 0.0226, 0.0162, 0.0203, 0.0284, 0.0218, 0.0208, 0.0254,
0.0220, 0.0357, 0.0288, 0.0290, 0.0235, 0.0218, 0.0188, 0.0279, 0.0232,
0.0238, 0.0195, 0.0256, 0.0255, 0.0204, 0.0198, 0.0211, 0.0219, 0.0262,
0.0253, 0.0246, 0.0177, 0.0209, 0.0216, 0.0253, 0.0261, 0.0215, 0.0257,
0.0240, 0.0197, 0.0206, 0.0270, 0.0243, 0.0218, 0.0261, 0.0350, 0.0238,
0.0243])... , zero_point=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.])... , _layout=PlainLayout()), block_size=[1, 200], shape=torch.Size([100, 200]), device=cpu, dtype=torch.float32, requires_grad=False)

original_weight_tensor:

AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[ 127, 0, 0, ..., 0, 0, 0],
[ 127, 0, 0, ..., 0, 0, 0],
[ 127, 0, 0, ..., 0, 0, 0],
...,
[ 47, 36, -70, ..., 49, 71, 5],
[ 117, -2, -91, ..., -112, 9, -81],
[ -67, -91, 114, ..., 51, 11, -126]], dtype=torch.int8)... , scale=tensor([7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01, 7.8431e+01,
2.3313e-02, 2.3492e-02, 2.3277e-02, 2.3458e-02, 2.3438e-02, 2.3528e-02,
2.3352e-02, 2.3522e-02, 2.3500e-02, 2.3332e-02, 2.3376e-02, 2.3481e-02,
2.3275e-02, 2.3509e-02, 2.3453e-02, 2.3460e-02, 2.3525e-02, 2.3489e-02,
2.3482e-02, 2.3436e-02, 2.3499e-02, 2.3523e-02, 2.3519e-02, 2.3320e-02,
2.3503e-02, 2.3453e-02, 2.3514e-02, 2.3496e-02, 2.3330e-02, 2.3444e-02,
2.3483e-02, 2.3428e-02, 2.3495e-02, 2.3445e-02, 2.3437e-02, 2.3505e-02,
2.3338e-02, 2.3517e-02, 2.3205e-02, 2.3469e-02, 2.3469e-02, 2.3506e-02,
2.3467e-02, 2.3497e-02, 2.3512e-02, 2.3497e-02, 2.3469e-02, 2.3511e-02,
2.3529e-02, 2.3445e-02, 2.3493e-02, 2.3527e-02, 2.3376e-02, 2.3366e-02,
2.3408e-02, 2.3410e-02, 2.3403e-02, 2.3441e-02, 2.3501e-02, 2.3426e-02,
2.3444e-02, 2.3502e-02, 2.3352e-02, 2.3501e-02, 2.3428e-02, 2.3424e-02,
2.3464e-02, 2.3414e-02, 2.3183e-02, 2.3088e-02, 2.3446e-02, 2.3220e-02,
2.3274e-02, 2.3457e-02, 2.3157e-02, 2.3419e-02, 2.3296e-02, 2.3498e-02,
2.3434e-02, 2.3407e-02, 2.3385e-02, 2.3437e-02, 2.3466e-02, 2.3503e-02,
2.3421e-02, 2.3364e-02, 2.3465e-02, 2.3410e-02, 2.3330e-02, 2.3472e-02,
2.3430e-02, 2.3522e-02, 2.3423e-02, 2.3422e-02, 2.3455e-02, 2.3503e-02,
2.3250e-02, 2.3400e-02, 2.3445e-02, 2.3399e-02, 2.3343e-02, 2.3464e-02,
2.3387e-02, 2.3443e-02, 2.3334e-02, 2.3378e-02, 2.3495e-02, 2.3394e-02,
2.3513e-02, 2.3255e-02, 2.3506e-02, 2.3516e-02, 2.3433e-02, 2.3354e-02,
2.3512e-02, 2.3358e-02, 2.3422e-02, 2.3400e-02, 2.3174e-02, 2.3437e-02,
2.3511e-02, 2.3354e-02, 2.3465e-02, 2.3322e-02, 2.3225e-02, 2.3226e-02,
2.3374e-02, 2.3380e-02, 2.3528e-02, 2.3435e-02, 2.3277e-02, 2.3491e-02,
2.3361e-02, 2.3392e-02, 2.3468e-02, 2.3253e-02, 2.3134e-02, 2.3092e-02,
2.3456e-02, 2.3519e-02, 2.3257e-02, 2.3524e-02, 2.3427e-02, 2.3493e-02,
2.3495e-02, 2.3376e-02, 2.3464e-02, 2.3408e-02, 2.3523e-02, 2.3171e-02])... , zero_point=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])... , _layout=PlainLayout()), block_size=(1, 200), shape=torch.Size([300, 200]), device=cpu, dtype=torch.float32, requires_grad=False)

The computation result of torch.nn.functional.linear(aqt, original_weight_tensor, bias) is consistent with that of torch.nn.functional.linear(aqt.dequantize(), original_weight_tensor.dequantize(), bias).

However, I cannot find the implementation of the overloaded function for torch.nn.functional.linear; where is it located?
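
For reference, the consistency check described above can be written out as a short script. This is only a sketch: aqt, original_weight_tensor, and bias are assumed to be the objects from the dumps above, and the tolerances are illustrative.

import torch
import torch.nn.functional as F

# `aqt` and `original_weight_tensor` are the AffineQuantizedTensor objects shown
# above; `bias` may be None. Because both are torch.Tensor subclasses, F.linear
# on them is intercepted by torchao's overload rather than the plain float kernel.
out_quantized = F.linear(aqt, original_weight_tensor, bias)
out_reference = F.linear(aqt.dequantize(), original_weight_tensor.dequantize(), bias)

# Illustrative tolerances; small numerical differences between the int8 kernel
# and the dequantized float reference path are expected.
print(torch.allclose(out_quantized, out_reference, rtol=1e-2, atol=1e-2))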

@yiliu30
Contributor

yiliu30 commented Dec 11, 2024

You can take a look at this function, I think.

@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if not input_tensor.is_floating_point():
        raise NotImplementedError(
            f"{func} is not implemented for non floating point input"
        )
    # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
    # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
    # make the branches easier to understand in `_quantized_linear_op`
    try:
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except QuantizedLinearNotImplementedError as e:
        # fallback path is only called when user did not specify a specific quantized linear implementation with `_layout.quantized_linear_impl`
        if (
            isinstance(weight_tensor, AffineQuantizedTensor)
            and hasattr(weight_tensor._layout, "quantized_linear_impl")
            and weight_tensor._layout.quantized_linear_impl is not None
        ):
            raise e
        if isinstance(input_tensor, AffineQuantizedTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, AffineQuantizedTensor):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
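
For context on how that function gets invoked in the first place: AffineQuantizedTensor is a torch.Tensor subclass, and the implements decorator registers the function above in a dispatch table that the subclass's __torch_function__ consults, so calling torch.nn.functional.linear with an AffineQuantizedTensor argument routes into the overload instead of the regular float kernel. Below is a simplified, hypothetical sketch of that pattern (a toy subclass and decorator, not torchao's actual classes):

import torch

_DISPATCH = {}

def implements(torch_fns):
    # Toy stand-in for torchao's `implements`: map each torch function to its override.
    def decorator(fn):
        for torch_fn in torch_fns:
            _DISPATCH[torch_fn] = fn
        return fn
    return decorator

class ToyQuantizedTensor(torch.Tensor):
    # When any torch function sees a ToyQuantizedTensor argument, PyTorch calls
    # __torch_function__, which looks up the registered override.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _DISPATCH:
            return _DISPATCH[func](func, types, args, kwargs)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

@implements([torch.nn.functional.linear])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    print("overridden linear called")
    # A real override would run the quantized kernel here; the toy version
    # just falls back to the plain float computation.
    with torch._C.DisableTorchFunctionSubclass():
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

x = torch.randn(4, 8)
w = torch.randn(16, 8).as_subclass(ToyQuantizedTensor)
y = torch.nn.functional.linear(x, w)  # prints "overridden linear called"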

@jerryzh168
Contributor

jerryzh168 commented Dec 11, 2024
