unhashable type: non-nested SymInt #1381

Open
bhack opened this issue Dec 4, 2024 · 32 comments
Labels: autoquant, bug (Something isn't working)

bhack commented Dec 4, 2024

With the pytorch and torchao nightlies:

    model = autoquant(
        model, 
    )
  File "/opt/conda/lib/python3.11/site-packages/torchao/quantization/autoquant.py", line 1140, in autoquant_prehook
    real_model.forward(*args, **kwargs)
    ....
    features = self.backbone(inputs)
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
...
    x = blk(x)
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/modeling/backbone/vit.py", line 362, in forward
    x = self.attn(x)
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
....
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, self.head_dim).permute(0, 3, 2, 1, 4).contiguous()

bhack commented Dec 4, 2024

pytorch/pytorch#135099


bhack commented Dec 5, 2024

@bdhirsh @jerryzh168 Do you think it is a duplicate of pytorch/pytorch#136287?

jerryzh168 commented:

@bhack can you try with pytorch nightly as well?

pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124


bhack commented Dec 5, 2024

It was already on pytorch nightly and ao nightly.

jerryzh168 commented:

I see. In that case, can you give us a minimal repro for the issue?


bhack commented Dec 5, 2024

It is quite a large model that I cannot share. Is there any debug information I can provide?

jerryzh168 commented:

I see. Maybe @bdhirsh can provide some pointers; it looks like this is related to torch.compile.


bhack commented Dec 5, 2024

Ok. In the meantime, I was looking at an example of ao + AOTI in the SAM2 server. Since in this case I am trying to use both, what is the general best practice for using ao with AOTI?


jerryzh168 commented Dec 5, 2024


bhack commented Dec 5, 2024

Thanks, I see the example isn't with autoquant. Does autoquant currently not play well with AOTI?


bhack commented Dec 5, 2024

If this helps, I get this error with

autoquant(model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,min_sqnr=40)

Instead with

quantize_(model, int8_dynamic_activation_int8_weight())
model = unwrap_tensor_subclass(model)

I got

cannot mutate tensors with frozen storage

While executing %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%to_2, %backbone_blocks_0_attn_qkv_bias), kwargs = {})
.....
 qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)


bhack commented Dec 5, 2024

For the 2nd case, where we had cannot mutate tensors with frozen storage, I moved the error by removing the .to from the example input and allocating it on the GPU directly.

But now I get the same cannot mutate tensors with frozen storage error from official timm code:

While executing %add__24 : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%to_26, %backbone_blocks_0_mlp_fc2_bias), kwargs = {})
Original traceback:
....
  File "/opt/conda/lib/python3.11/site-packages/timm/layers/mlp.py", line 46, in forward
    x = self.fc2(x)


bhack commented Dec 5, 2024

Let me know if I can help debug both the quantize_ and autoquant errors further.


bdhirsh commented Dec 5, 2024

I'll let Jerry comment on the autoquant + AOTI support. As for the cannot mutate tensors with frozen storage error, I think the export team still wants to fix this, but if you end up hitting it you can usually work around it by sprinkling in a clone() at the right place (usually right before the mutation referenced in the stack trace, e.g. that aten.add_).
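A minimal torch-only sketch of that clone() workaround (the module and names are hypothetical stand-ins, not the actual model):

```python
import torch

# Hypothetical minimal module mirroring the failing pattern: under export,
# an in-place add_ into a tensor that aliases frozen (parameter) storage
# can fail with "cannot mutate tensors with frozen storage"; cloning right
# before the mutation gives it its own, mutable storage.
class QKVBias(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(8))

    def forward(self, x):
        # x.add_(self.bias)   # the kind of in-place op export complains about
        x = x.clone()         # workaround: clone immediately before the mutation
        x.add_(self.bias)
        return x
```

With the clone() in place, the in-place add happens on a fresh buffer, so export no longer sees a write into frozen storage and the caller's input is left untouched.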


bhack commented Dec 5, 2024

@bdhirsh Right, but for the mutation case I really cannot track down the source point for the clone. I've already added clone() in many places, but it always fails with the same aten.add_. How can I find the exact source point in this complex model?


bdhirsh commented Dec 5, 2024

What is the current stack trace that you get? Hopefully it points somewhere useful in the user Python stack, closer to where the mutation was. (If not, and you have some repro code, I can take a look too, although ideally we'd make the user-stack situation better.)


bhack commented Dec 5, 2024

I don't think the stack trace is very useful:

While executing %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%to_2, %backbone_blocks_0_attn_qkv_bias), kwargs = {})
Original traceback:
  features = self.backbone(inputs)
  File "/workspace/modeling/backbone/vit.py", line 420, in forward
    x = blk(x)
  File "/workspace/modeling/backbone/vit.py", line 279, in forward
    x = self.attn(x)
  File "/workspace/modeling/backbone/vit.py", line 70, in forward
    qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)


bdhirsh commented Dec 5, 2024

Yeah, that stack trace alone isn't very helpful, although it doesn't look complete. If you're able to provide the entire stack trace, links to the backbone/vit.py code, or a repro, that would help a bunch.


bhack commented Dec 5, 2024


bdhirsh commented Dec 5, 2024

I didn't see an obvious source of the in-place add call either. However, I also tried instantiating a ViT() instance from that repo and exporting it locally, and that didn't error for me:

# I just pasted this in locally at the bottom of `backbone/vit.py`
m = ViT()
x = torch.randn(4, 3, 128, 128)
m = torch.export.export(m, (x,))
m2 = m.run_decompositions()
print(m2)


bhack commented Dec 5, 2024

export itself was always working. This part of the issue appears only with the extra quantize_ step before the export:

quantize_(model, int8_dynamic_activation_int8_weight())
model = unwrap_tensor_subclass(model)

jerryzh168 commented:

> Thanks, I see the example isn't with autoquant. Does autoquant currently not play well with AOTI?

I haven't tested serialization for autoquant yet; last time I tried there seemed to be some performance issues. I will probably debug it a bit later.


bhack commented Dec 6, 2024

@jerryzh168 Do you have any hint on where this add_ is being injected?

cannot mutate tensors with frozen storage

While executing %add_ : [num_users=1] = call_function[target=torch.ops.aten.add_.Tensor](args = (%to_2, %backbone_blocks_0_attn_qkv_bias), kwargs = {})
.....
 qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)

This exported fine before; with exactly the same code, the error only appears after we add quantization before the export:

quantize_(model, int8_dynamic_activation_int8_weight())
model = unwrap_tensor_subclass(model)


bhack commented Dec 6, 2024

I see a similar mutate torch.ops.aten.add_.Tensor error with ao at pytorch/pytorch#139718


bdhirsh commented Dec 6, 2024

I was able to repro locally - looking around a bit, the mutation was coming from here. I patched torchao here, and with that patch I could run the export E2E: #1387.

There are two things that are still worth looking into on the export side:

(1) Ideally the error message from export should have properly pointed to that culprit code. It looks like the stack is getting lost somewhere (cc @tugsbayasgalan, @yushangdi)

(2) we should also actually fix pytorch/pytorch#127571 so no user changes are required in the first place


bhack commented Dec 6, 2024

@bdhirsh thanks for confirming the repro.
Just an extra point:
Can you confirm that the unwrap is still required on >=2.5 (including the current nightly)? The current ao docs suggest it is not required, but I get a serialization error with this example when I comment out the unwrap.


bdhirsh commented Dec 6, 2024

So the situation with that unwrap API is that:

(1) if you are using torch.compile, it should not be needed if you are >=2.5.1

(2) if you are using export, you currently need that API. @tugsbayasgalan is working on a more general-purpose change to torch.export that will automatically handle subclass parameters, so this won't be necessary in the near future


bhack commented Dec 6, 2024

Yes, I saw that PR. It is just that the current ao doc is ambiguous about the compile vs. export use case.

I don't know if you have the time, but since you have the code now, can you also confirm the autoquant(model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST, min_sqnr=40) repro?


bdhirsh commented Dec 6, 2024

Just tried, looks like that doesn't work.

with export.export() (which uses "strict mode"), I get a dynamo error:

    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: unregistered hook removable handle

from user code:
   File "/home/hirsheybar/local/a/pytorch/ao/torchao/quantization/autoquant.py", line 1141, in autoquant_prehook
    module.finalize_autoquant()

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

with non-strict export (export(strict=False)), I get some warnings and a crash (cc @jerryzh168):

activation_shapes: torch.Size([256, 768]), times_seen: 1
weight_shape: torch.Size([2304, 768]), dtype: torch.float32, bias_shape: torch.Size([2304])
warning: failed to autoquant AQFloat32LinearWeight for shape: (torch.Size([256, 768]), torch.Size([2304, 768]), torch.Size([2304]), torch.float32) due to Creating a new Tensor subclass AQFloat32LinearWeight but the raw Tensor object is already associated to a python object of type FakeTensor which is not a subclass of the requested type
warning: failed to autoquant AQBFloat16LinearWeight for shape: (torch.Size([256, 768]), torch.Size([2304, 768]), torch.Size([2304]), torch.float32) due to Creating a new Tensor subclass AQBFloat16LinearWeight but the raw Tensor object is already associated to a python object of type FakeTensor which is not a subclass of the requested type
warning: failed to autoquant AQFloat16LinearWeight for shape: (torch.Size([256, 768]), torch.Size([2304, 768]), torch.Size([2304]), torch.float32) due to Creating a new Tensor subclass AQFloat16LinearWeight but the raw Tensor object is already associated to a python object of type FakeTensor which is not a subclass of the requested type
best_cls=None

...

Traceback (most recent call last):
  File "/home/hirsheybar/local/a/pytorch/ViTMatte/modeling/backbone/vit.py", line 686, in <module>
    m = torch.export.export(m, (x,), strict=False)
  File "/home/hirsheybar/local/a/pytorch/torch/export/__init__.py", line 368, in export
    return _export(
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1031, in wrapper
    raise e
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1004, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/hirsheybar/local/a/pytorch/torch/export/exported_program.py", line 127, in wrapper
    return fn(*args, **kwargs)
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1967, in _export
    return _export_for_training(
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1031, in wrapper
    raise e
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1004, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/hirsheybar/local/a/pytorch/torch/export/exported_program.py", line 127, in wrapper
    return fn(*args, **kwargs)
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1831, in _export_for_training
    export_artifact = export_func(  # type: ignore[operator]
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1769, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1561, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1699, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "/home/hirsheybar/local/a/pytorch/torch/export/_trace.py", line 1481, in _make_fx_helper
    with ctx:
  File "/home/hirsheybar/local/a/pytorch-env/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
  File "/home/hirsheybar/local/a/pytorch/torch/_functorch/aot_autograd.py", line 1592, in _detect_attribute_assignment
    pytree.tree_map_with_path(
  File "/home/hirsheybar/local/a/pytorch/torch/utils/_pytree.py", line 1607, in tree_map_with_path
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/hirsheybar/local/a/pytorch/torch/utils/_pytree.py", line 1607, in <listcomp>
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/hirsheybar/local/a/pytorch/torch/utils/_pytree.py", line 798, in flatten_up_to
    self._flatten_up_to_helper(tree, subtrees)
  File "/home/hirsheybar/local/a/pytorch/torch/utils/_pytree.py", line 756, in _flatten_up_to_helper
    raise ValueError(
ValueError: Node arity mismatch; expected 5, but got 4.


bhack commented Dec 6, 2024

Thanks. Let me know whether we should track these here or open new tickets.


jerryzh168 commented Dec 7, 2024

@bdhirsh for the autoquant error, how did you run it? You'll need to feed the model some input before export to trigger autoquant: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization

something like this:

model = autoquant(model, ...)
model(input)
model = torch.export.export_for_training(model)

for add_, it seems best to fix it on the export side, if possible. cc @yushangdi @tugsbayasgalan
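For context on why that calibration call matters: the tracebacks earlier in this thread show that autoquant instruments the model with a forward pre-hook (autoquant_prehook, which calls finalize_autoquant). A torch-only sketch of that one-shot pre-hook mechanism (all names here are hypothetical, not torchao's actual code):

```python
import torch

# Sketch of a one-shot forward pre-hook: it runs on the first real input,
# records what it needs, then removes itself. The removable handle it
# holds is also the kind of object that strict-mode export/dynamo rejected
# above with "unregistered hook removable handle".
class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)
        self.calibrated = False

    def forward(self, x):
        return self.fc(x)

def calibrate_prehook(module, args):
    module.calibrated = True  # stand-in for shape recording / finalization
    handle.remove()           # one-shot: unregister after the first call

m = Tiny()
handle = m.register_forward_pre_hook(calibrate_prehook)
m(torch.randn(2, 4))  # the "feed the model some input" step from the comment above
```

This is why exporting before running an input leaves the hook in place: the model must see at least one forward pass so the hook can fire and unregister itself.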


bhack commented Dec 9, 2024

@jerryzh168 The autoquant call in my test produces
unhashable type: non-nested SymInt, as in the ticket title.

@supriyar supriyar added autoquant bug Something isn't working labels Dec 12, 2024