Skip to content

Commit

Permalink
Revert MiDaS CPU workaround for MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
brkirch committed Aug 23, 2023
1 parent 943197e commit cea8af0
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions annotator/midas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ def apply_midas(input_image, a=np.pi * 2.0, bg_th=0.1):
global model
if model is None:
model = MiDaSInference(model_type="dpt_hybrid")
if devices.get_device_for("controlnet").type != 'mps':
model = model.to(devices.get_device_for("controlnet"))
model = model.to(devices.get_device_for("controlnet"))

assert input_image.ndim == 3
image_depth = input_image
with torch.no_grad():
image_depth = torch.from_numpy(image_depth).float()
if devices.get_device_for("controlnet").type != 'mps':
image_depth = image_depth.to(devices.get_device_for("controlnet"))
image_depth = image_depth.to(devices.get_device_for("controlnet"))
image_depth = image_depth / 127.5 - 1.0
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
depth = model(image_depth)[0]
Expand Down

0 comments on commit cea8af0

Please sign in to comment.