Skip to content

Commit

Permalink
Merge branch 'main' into fix-torchrl-deps
Browse files Browse the repository at this point in the history
  • Loading branch information
svekars authored Jul 31, 2024
2 parents e3c7f14 + f5c28eb commit 1feb07f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
7 changes: 5 additions & 2 deletions .ci/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ pytorch-lightning
torchx
torchrl==0.5.0
tensordict==0.5.0
ax-platform
# TODO: use stable 0.5 when released
-e git+https://github.com/pytorch/rl.git#egg=torchrl
-e git+https://github.com/pytorch/tensordict.git#egg=tensordict
ax-platform>=0.4.0
nbformat>=5.9.2
datasets
transformers
Expand Down Expand Up @@ -68,4 +71,4 @@ pygame==2.1.2
pycocotools
semilearn==0.3.2
torchao==0.0.3
segment_anything==1.0
segment_anything==1.0
22 changes: 11 additions & 11 deletions intermediate_source/ax_multiobjective_nas_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,21 +232,21 @@ def trainer(
# we get the logic to read and parse the TensorBoard logs for free.
#

from ax.metrics.tensorboard import TensorboardCurveMetric
from ax.metrics.tensorboard import TensorboardMetric
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer


class MyTensorboardMetric(TensorboardCurveMetric):
class MyTensorboardMetric(TensorboardMetric):

# NOTE: We need to tell the new TensorBoard metric how to get the id /
# file handle for the TensorBoard logs from a trial. In this case
# our convention is to just save a separate file per trial in
# the prespecified log dir.
@classmethod
def get_ids_from_trials(cls, trials):
return {
trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix()
for trial in trials
}
def _get_event_multiplexer_for_trial(self, trial):
mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)
mul.Reload()

return mul

# This indicates whether the metric is queryable while the trial is
# still running. We don't use this in the current tutorial, but Ax
Expand All @@ -266,12 +266,12 @@ def is_available_while_running(cls):

val_acc = MyTensorboardMetric(
name="val_acc",
curve_name="val_acc",
tag="val_acc",
lower_is_better=False,
)
model_num_params = MyTensorboardMetric(
name="num_params",
curve_name="num_params",
tag="num_params",
lower_is_better=True,
)

Expand Down
2 changes: 1 addition & 1 deletion prototype_source/pt2e_quantizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Prerequisites:

Required:

- `Torchdynamo concepts in PyTorch <https://pytorch.org/docs/stable/dynamo/index.html>`__
- `Torchdynamo concepts in PyTorch <https://pytorch.org/docs/stable/torch.compiler_dynamo_overview.html>`__

- `Quantization concepts in PyTorch <https://pytorch.org/docs/master/quantization.html#quantization-api-summary>`__

Expand Down

0 comments on commit 1feb07f

Please sign in to comment.