Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoweringContext] Support explicit device data parameters for scalar inputs #8414

Open
rpsilva-aws opened this issue Nov 25, 2024 · 0 comments
Assignees

Comments

@rpsilva-aws
Copy link
Contributor

🚀 Feature

Extend the lowering context to handle 0 and 1 value scalars as device data parameters. As it stands, IsSpecialScalar, by default, captures 0 and 1 input values as ScalarOps IR values instead of device data when lowering the output ops.

Motivation

At the moment, these scalar values are treated as IR constants and inlined into the computation. At the same time, users have no control over recompilations when re-using the same computation with a mutable scalar parameter. This is particularly prominent when generating the body computation for while loops.
In addition, not having the proper scalar parameters present may cause unexpected signatures in the input/output, particularly when these are expected or enforced elsewhere, namely OpenXLA's while loop input/output requirements for the body/cond.

In some cases, the parameter is omitted entirely, or the provided parameter is a no-op, because the computation contains an inlined constant instead [1].

Pitch

The user should be able to explicitly enforce that only constants are inlined, and/or that only constants not given as a parameter are inlined.

[1]:

diff --git a/test/test_while_loop.py b/test/test_while_loop.py
index e8ea617b0..ee5254dda 100644
--- a/test/test_while_loop.py
+++ b/test/test_while_loop.py
@@ -84,7 +84,7 @@ class WhileLoopTest(unittest.TestCase):
     linear_model = SimpleLinear()
     linear_model.to(device)
     l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device)
-    iteri = torch.tensor(10, dtype=torch.int32, device=device)
+    iteri = torch.tensor(1, dtype=torch.int32, device=device)
     _, res_with_loop = linear_model(iteri, l_in_0)
     _, res_without_loop = linear_model.forward_without_while_loop_op(
         iteri, l_in_0)
  • Body
HloModule PyLoweringContext.18.25, entry_computation_layout={((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}))->(s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0})}

%PyLoweringContext.6 (p0.12: f32[2], p1.13: f32[2,2], p2.15: f32[2,2], UnusedArgumentsPlaceholder.22: f32[2,2]) -> (s32[], f32[2], f32[2,2], f32[2,2]) {
  %UnusedArgumentsPlaceholder.22 = f32[2,2]{1,0} parameter(3)
  %constant.9 = s32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/workspaces/pytorch/xla/test/test_while_loop.py" source_line=89}
  %constant.8 = s32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<eval_with_key>.1" source_line=5}
  %constant.7 = s32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<eval_with_key>.1" source_line=5}
  %multiply.10 = s32[] multiply(s32[] %constant.8, s32[] %constant.7), metadata={op_type="aten__sub" op_name="aten__sub.3/aten__sub" source_file="<eval_with_key>.1" source_line=5}
  %subtract.11 = s32[] subtract(s32[] %constant.9, s32[] %multiply.10), metadata={op_type="aten__sub" op_name="aten__sub.3/aten__sub" source_file="<eval_with_key>.1" source_line=5}
  %p0.12 = f32[2]{0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/pytorch/torch/nn/modules/module.py" source_line=1326}
  %p1.13 = f32[2,2]{1,0} parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/pytorch/torch/nn/modules/module.py" source_line=1326}
  %p2.15 = f32[2,2]{1,0} parameter(2), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/workspaces/pytorch/xla/torch_xla/experimental/fori_loop.py" source_line=78}
  %transpose.14 = f32[2,2]{0,1} transpose(f32[2,2]{1,0} %p1.13), dimensions={1,0}, metadata={op_type="aten__as_strided" op_name="aten__as_strided" source_file="<eval_with_key>.1" source_line=6}
  %dot.16 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %p2.15, f32[2,2]{0,1} %transpose.14), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %reshape.17 = f32[1,2]{1,0} reshape(f32[2]{0} %p0.12), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %broadcast.18 = f32[1,2]{1,0} broadcast(f32[1,2]{1,0} %reshape.17), dimensions={0,1}, metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %reshape.19 = f32[2]{0} reshape(f32[1,2]{1,0} %broadcast.18), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %broadcast.20 = f32[2,2]{1,0} broadcast(f32[2]{0} %reshape.19), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  %add.21 = f32[2,2]{1,0} add(f32[2,2]{1,0} %dot.16, f32[2,2]{1,0} %broadcast.20), metadata={op_type="aten__addmm" op_name="aten__addmm" source_file="<eval_with_key>.1" source_line=6}
  ROOT %tuple.23 = (s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(s32[] %subtract.11, f32[2]{0} %p0.12, f32[2,2]{1,0} %p1.13, f32[2,2]{1,0} %add.21)
}

ENTRY %PyLoweringContext.18.25 (in.1: (f32[2], f32[2,2], f32[2,2], f32[2,2])) -> (s32[], f32[2], f32[2,2], f32[2,2]) {
  %in.1 = (f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.2 = f32[2]{0} get-tuple-element((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=0
  %get-tuple-element.3 = f32[2,2]{1,0} get-tuple-element((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=1
  %get-tuple-element.4 = f32[2,2]{1,0} get-tuple-element((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=2
  %get-tuple-element.5 = f32[2,2]{1,0} get-tuple-element((f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=3
  ROOT %call.24 = (s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) call(f32[2]{0} %get-tuple-element.2, f32[2,2]{1,0} %get-tuple-element.3, f32[2,2]{1,0} %get-tuple-element.4, f32[2,2]{1,0} %get-tuple-element.5), to_apply=%PyLoweringContext.6
}
  • Condition
HloModule PyLoweringContext.9.16, entry_computation_layout={((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}))->pred[]}

%PyLoweringContext.6 (UnusedArgumentsPlaceholder.11: s32[], UnusedArgumentsPlaceholder.12: f32[2], UnusedArgumentsPlaceholder.13: f32[2,2], UnusedArgumentsPlaceholder.14: f32[2,2]) -> pred[] {
  %constant.8 = s32[] constant(1), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/workspaces/pytorch/xla/test/test_while_loop.py" source_line=89}
  %convert.9 = s64[] convert(s32[] %constant.8), metadata={op_type="aten__gt" op_name="aten__gt" source_file="<eval_with_key>.0" source_line=5}
  %constant.7 = s64[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="<eval_with_key>.0" source_line=5}
  ROOT %compare.10 = pred[] compare(s64[] %convert.9, s64[] %constant.7), direction=GT, metadata={op_type="aten__gt" op_name="aten__gt" source_file="<eval_with_key>.0" source_line=5}
  %UnusedArgumentsPlaceholder.11 = s32[] parameter(0)
  %UnusedArgumentsPlaceholder.12 = f32[2]{0} parameter(1)
  %UnusedArgumentsPlaceholder.13 = f32[2,2]{1,0} parameter(2)
  %UnusedArgumentsPlaceholder.14 = f32[2,2]{1,0} parameter(3)
}

ENTRY %PyLoweringContext.9.16 (in.1: (s32[], f32[2], f32[2,2], f32[2,2])) -> pred[] {
  %in.1 = (s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.2 = s32[] get-tuple-element((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=0
  %get-tuple-element.3 = f32[2]{0} get-tuple-element((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=1
  %get-tuple-element.4 = f32[2,2]{1,0} get-tuple-element((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=2
  %get-tuple-element.5 = f32[2,2]{1,0} get-tuple-element((s32[], f32[2]{0}, f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=3
  ROOT %call.15 = pred[] call(s32[] %get-tuple-element.2, f32[2]{0} %get-tuple-element.3, f32[2,2]{1,0} %get-tuple-element.4, f32[2,2]{1,0} %get-tuple-element.5), to_apply=%PyLoweringContext.6
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants