Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Control flow support #124

Open
wants to merge 13 commits into
base: master
Choose a base branch
from

Conversation

HeydrichBeillschmidt
Copy link

[WIP] support tuple-shaped parameters for while instruction

@tdietert
Copy link

Hi @HeydrichBeillschmidt, when I merge your changes into my fork and try to call run_auto_sharding_pass on a simple MNIST model, I get this error:

  File "/workspaces/alpa/alpa/shard_parallel/auto_sharding.py", line 355, in run_auto_sharding_pass
    xe.run_auto_sharding(hlo_module, compile_options)
IndexError: absl::container_internal::raw_hash_map<>::at

The source of the error is the CreateStrategyVector code, where apparently a select operation has not been added to the strategy_map, and thus results in an error when iterating through the operands of the dot.278 instruction. Below is some HLO that comes from an intermediate stage of compilation, after the spmd_simplify pipeline, and before the spmd_pipeline that runs the auto sharding pass:

  broadcast.6 = f32[2048,1600]{1,0} broadcast(constant.171), dimensions={}
  select = f32[2048,1600]{1,0} select(compare.183, reshape.29, broadcast.6), metadata={op_type="Mul" op_name="mnist/sequential/dropout/dropout/Mul_1" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/backend.py" source_line=1940}
  arg34.35 = f32[1600,10]{1,0} parameter(34), parameter_replication={false}, metadata={op_name="XLA_Args"}
  dot.268 = f32[2048,10]{1,0} dot(select, arg34.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="MatMul" op_name="mnist/sequential/dense/MatMul" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/layers/core/dense.py" source_line=221}

And finally, here is some logging output I've generated that shows the sequence of events leading up to this failed indexing into the strategy map:

HandleDot[0]: dot.268
CreateLeafStrategyVector: dot.268
Potential Failing operand instruction: %select = f32[2048,1600]{1,0} select(pred[2048,1600]{1,0} %compare.183, f32[2048,1600]{1,0} %reshape.29, f32[2048,1600]{1,0} %broadcast.6), metadata={op_type="Mul" op_name="mnist/sequential/dropout/dropout/Mul_1" source_file="/home/vscode/.local/lib/python3.8/site-packages/keras/backend.py" source_line=1940}

Do you have any idea what could be the problem?

@tdietert
Copy link

@HeydrichBeillschmidt I've solved this problem by undoing the part of the diff where you build an instruction sequence from the entry_computation->instructions() list. You passed this entry_sequence value to BuildStrategyAndCost, instead of the sequence value constructed from the hlo_live_range, but it doesn't actually contain all the instructions in the computation: https://github.com/alpa-projects/tensorflow-alpa/pull/124/files#diff-83aa23c5123bde398bcd2002e8bf5d5bdf79341e11f461715a127f9547357a13R2806

Is there a reason you did this? Replacing entry_sequence with sequence (from the hlo_live_range value, like in the master branch) passed to BuildStrategyAndCost solved my issue.

@HeydrichBeillschmidt
Copy link
Author

@HeydrichBeillschmidt I've solved this problem by undoing the part of the diff where you build an instruction sequence from the entry_computation->instructions() list. You passed this entry_sequence value to BuildStrategyAndCost, instead of the sequence value constructed from the hlo_live_range, but it doesn't actually contain all the instructions in the computation: https://github.com/alpa-projects/tensorflow-alpa/pull/124/files#diff-83aa23c5123bde398bcd2002e8bf5d5bdf79341e11f461715a127f9547357a13R2806

Is there a reason you did this? Replacing entry_sequence with sequence (from the hlo_live_range value, like in the master branch) passed to BuildStrategyAndCost solved my issue.

Hi @tdietert , thank you for your issue. The BuildStrategyAndCost is designed as a recursive structure, and entry_sequence here is passed for avoiding repeated construction for instructions in computations such as while body. However, simply letting entry_sequence = entry_computation->instructions() was incorrect. The problem is addressed in the latest commit.

@tdietert
Copy link

tdietert commented Sep 2, 2022

@HeydrichBeillschmidt Thanks for your response! We have tried your latest changes and they work well for us, thank you. We have not validated the output, that the while loops are parallelized "correctly", but we don't experience any of the issues we experienced before.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants