Commit 64cd480
Merge pull request #1467 from rstudio/retether-vignettes-2024-09-04
retether vignettes
t-kalinowski authored Sep 4, 2024
2 parents 6ff11d2 + 8d4e9af commit 64cd480
Showing 7 changed files with 12 additions and 15 deletions.
4 changes: 1 addition & 3 deletions .tether/vignettes-src/distribution.Rmd
@@ -188,9 +188,7 @@ layout_map["d1/bias"] = ("model",)
 # You can also set the layout for the layer output like
 layout_map["d2/output"] = ("data", None)
 
-model_parallel = keras.distribution.ModelParallel(
-    mesh_2d, layout_map, batch_dim_name="data"
-)
+model_parallel = keras.distribution.ModelParallel(layout_map, batch_dim_name="data")
 
 keras.distribution.set_distribution(model_parallel)
 
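Both copies of this vignette drop the positional `mesh_2d` argument from `ModelParallel()`: in the newer Keras 3 distribution API the device mesh travels with the `LayoutMap`, so only `layout_map` and `batch_dim_name` are passed. A minimal, hypothetical sketch of the updated usage (not part of this diff; it assumes the JAX backend and 8 accelerators so a 4 x 2 mesh can be built, and it reuses the vignette's `d1`/`d2` layer names):

# Hypothetical sketch of the post-change API, not code from this commit.
import keras

devices = keras.distribution.list_devices()  # assumes 8 accelerators

# The device mesh is now carried by the LayoutMap ...
mesh_2d = keras.distribution.DeviceMesh(
    shape=(4, 2), axis_names=["model", "data"], devices=devices
)
layout_map = keras.distribution.LayoutMap(device_mesh=mesh_2d)
layout_map["d1/bias"] = ("model",)
layout_map["d2/output"] = ("data", None)

# ... so ModelParallel no longer takes the mesh as an argument.
model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="data"
)
keras.distribution.set_distribution(model_parallel)

The R hunk further down makes the same change to `keras$distribution$ModelParallel()`.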
4 changes: 2 additions & 2 deletions .tether/vignettes-src/parked/_custom_train_step_in_torch.Rmd
@@ -2,7 +2,7 @@
 title: Customizing what happens in `fit()` with PyTorch
 author: '[fchollet](https://twitter.com/fchollet)'
 date-created: 2023/06/27
-last-modified: 2023/06/27
+last-modified: 2024/08/01
 description: Overriding the training step of the Model class with PyTorch.
 accelerator: GPU
 output: rmarkdown::html_vignette
@@ -390,7 +390,7 @@ class GAN(keras.Model):
 
     def train_step(self, real_images):
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        if isinstance(real_images, tuple):
+        if isinstance(real_images, tuple) or isinstance(real_images, list):
             real_images = real_images[0]
         # Sample random points in the latent space
         batch_size = real_images.shape[0]
@@ -153,7 +153,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y)
 ```
 
 Once you have such a function, you can get the gradient function by
-specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss
+specifying `has_aux` in `value_and_grad`: it tells JAX that the loss
 computation function returns more outputs than just the loss. Note that the loss
 should always be the first output.
 
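The corrected argument name is `has_aux`. As a quick, self-contained illustration (not from the vignette; the toy loss function and array shapes are made up), `jax.value_and_grad(..., has_aux=True)` expects the wrapped function to return the loss first and then returns `((loss, aux), grads)`:

# Hypothetical illustration of has_aux=True, not code from the vignette.
import jax
import jax.numpy as jnp

def compute_loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    aux = {"pred_mean": jnp.mean(pred)}  # anything extra goes after the loss
    return loss, aux

# has_aux=True tells JAX the function returns (loss, aux), not just the loss.
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((8, 3))
y = jnp.zeros((8, 1))
(loss, aux), grads = grad_fn(params, x, y)  # grads mirrors the structure of params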
2 changes: 1 addition & 1 deletion .tether/vignettes-src/writing_your_own_callbacks.Rmd
@@ -293,7 +293,7 @@ class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
         # The epoch the training stops at.
         self.stopped_epoch = 0
         # Initialize the best as infinity.
-        self.best = np.Inf
+        self.best = np.inf
 
     def on_epoch_end(self, epoch, logs=None):
         current = logs.get("loss")
9 changes: 4 additions & 5 deletions vignettes-src/distribution.Rmd
@@ -93,7 +93,7 @@ mesh <- keras$distribution$DeviceMesh(
 # "data" as columns, and it is a [4, 2] grid when it mapped to the physical
 # devices on the mesh.
 layout_2d <- keras$distribution$TensorLayout(
-  axes = c("model", "data"),
+  axes = c("model", "data"),
   device_mesh = mesh
 )
@@ -131,8 +131,8 @@ data_parallel <- keras$distribution$DataParallel(devices = devices)
 # Or you can choose to create DataParallel with a 1D `DeviceMesh`.
 mesh_1d <- keras$distribution$DeviceMesh(
-  shape = shape(8),
-  axis_names = list("data"),
+  shape = shape(8),
+  axis_names = list("data"),
   devices = devices
 )
 data_parallel <- keras$distribution$DataParallel(device_mesh = mesh_1d)
@@ -213,8 +213,7 @@ layout_map["d1/bias"] <- tuple("model")
 layout_map["d2/output"] <- tuple("data", NULL)
 model_parallel <- keras$distribution$ModelParallel(
-  layout_map = layout_map,
-  batch_dim_name = "data"
+  layout_map, batch_dim_name = "data"
 )
 keras$distribution$set_distribution(model_parallel)
4 changes: 2 additions & 2 deletions vignettes-src/parked/_custom_train_step_in_torch.Rmd
@@ -2,7 +2,7 @@
 title: Customizing what happens in `fit()` with PyTorch
 author: '[fchollet](https://twitter.com/fchollet)'
 date-created: 2023/06/27
-last-modified: 2023/06/27
+last-modified: 2024/08/01
 description: Overriding the training step of the Model class with PyTorch.
 accelerator: GPU
 output: rmarkdown::html_vignette
@@ -390,7 +390,7 @@ class GAN(keras.Model):
 
     def train_step(self, real_images):
        device = "cuda" if torch.cuda.is_available() else "cpu"
-        if isinstance(real_images, tuple):
+        if isinstance(real_images, tuple) or isinstance(real_images, list):
            real_images = real_images[0]
         # Sample random points in the latent space
         batch_size = real_images.shape[0]
@@ -153,7 +153,7 @@ def compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y)
 ```
 
 Once you have such a function, you can get the gradient function by
-specifying `hax_aux` in `value_and_grad`: it tells JAX that the loss
+specifying `has_aux` in `value_and_grad`: it tells JAX that the loss
 computation function returns more outputs than just the loss. Note that the loss
 should always be the first output.
 
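The same `hax_aux` -> `has_aux` fix lands in the R source of the vignette. A hedged sketch of how the corrected argument can be exercised from R via reticulate (not code from the vignette; `compute_loss_and_aux` and the toy arrays are stand-ins for the vignette's own `compute_loss_and_updates`, and JAX is assumed to be importable):

# Hypothetical sketch, not code from the vignette.
library(reticulate)
jax <- import("jax")
jnp <- import("jax.numpy")

compute_loss_and_aux <- function(w, x, y) {
  pred <- jnp$matmul(x, w)
  loss <- jnp$mean(jnp$square(jnp$subtract(pred, y)))
  aux <- jnp$mean(pred)
  tuple(loss, aux)  # the loss must come first
}

# has_aux = TRUE: the wrapped function returns (loss, aux), not just the loss.
grad_fn <- jax$value_and_grad(compute_loss_and_aux, has_aux = TRUE)

w <- jnp$ones(c(3L, 1L))
x <- jnp$ones(c(8L, 3L))
y <- jnp$zeros(c(8L, 1L))
res <- grad_fn(w, x, y)  # res[[1]] holds (loss, aux); res[[2]] is the gradient w.r.t. w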

0 comments on commit 64cd480
