diff --git a/userbenchmark/dynamo/dynamobench/common.py b/userbenchmark/dynamo/dynamobench/common.py index 154651d4f..3d4ff7199 100644 --- a/userbenchmark/dynamo/dynamobench/common.py +++ b/userbenchmark/dynamo/dynamobench/common.py @@ -2214,6 +2214,10 @@ def skip_models_due_to_control_flow(self): def guard_on_nn_module_models(self): return set() + @property + def inline_inbuilt_nn_modules_models(self): + return set() + def get_tolerance_and_cosine_flag(self, is_training, current_device, name): raise NotImplementedError @@ -4218,16 +4222,21 @@ def detect_and_mark_batch(t): if name in runner.guard_on_nn_module_models: guard_ctx = torch._dynamo.config.patch(guard_nn_modules=True) + inline_ctx = contextlib.nullcontext() + if name in runner.inline_inbuilt_nn_modules_models: + inline_ctx = torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) + with guard_ctx: - runner.run_one_model( - name, - model, - example_inputs, - optimize_ctx, - experiment, - explain=args.explain, - tag=args.tag, - ) + with inline_ctx: + runner.run_one_model( + name, + model, + example_inputs, + optimize_ctx, + experiment, + explain=args.explain, + tag=args.tag, + ) if args.generate_aot_autograd_stats: stats_file = output_filename.split(".csv")[0] + "_stats.csv" output_csv( diff --git a/userbenchmark/dynamo/dynamobench/torchbench.py b/userbenchmark/dynamo/dynamobench/torchbench.py index d7877c5a3..61175b461 100755 --- a/userbenchmark/dynamo/dynamobench/torchbench.py +++ b/userbenchmark/dynamo/dynamobench/torchbench.py @@ -217,6 +217,19 @@ def guard_on_nn_module_models(self): "vision_maskrcnn", } + @property + def inline_inbuilt_nn_modules_models(self): + return { + "basic_gnn_edgecnn", + "drq", + "hf_Reformer", + "DALLE2_pytorch", + "hf_BigBird", + "detectron2_maskrcnn_r_50_fpn", + "detectron2_maskrcnn_r_101_fpn", + "vision_maskrcnn", + } + def load_model( self, device,