Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sdbds committed Jul 15, 2024
1 parent f46298f commit 5f5b51f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
16 changes: 9 additions & 7 deletions annotator/mobile_sam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ class SamDetector_Aux(SamDetector):

model_dir = os.path.join(models_path, "mobile_sam")

def __init__(self, mask_generator: SamAutomaticMaskGenerator):
def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam):
super().__init__(mask_generator)
self.device = devices.device
self.model = sam.to(self.device).eval()

@classmethod
def from_pretrained(cls):
Expand All @@ -35,13 +37,13 @@ def from_pretrained(cls):

sam = sam_model_registry["vit_t"](checkpoint=model_path)

cls.device = devices.device
cls.model = SamDetector_Aux().to(cls.device).eval()
cls.model = sam.to(devices.device).eval()

mask_generator = SamAutomaticMaskGenerator(sam)
mask_generator = SamAutomaticMaskGenerator(cls.model)

return cls(mask_generator)
return cls(mask_generator, sam)

def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> np.ndarray:
def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray:
self.model.to(self.device)
super().__call__(image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs)
image = super().__call__(input_image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs)
return np.array(image).astype(np.uint8)
6 changes: 3 additions & 3 deletions scripts/preprocessor/mobile_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def __call__(
slider_3=None,
**kwargs
):
img, remove_pad = resize_image_with_pad(input_image, resolution)
#img, remove_pad = resize_image_with_pad(input_image, resolution)
if self.model is None:
self.model = SamDetector_Aux.from_pretrained()

result = self.model(img, detect_resolution=resolution, image_resolution=resolution)
return remove_pad(result)
result = self.model(input_image, detect_resolution=resolution, image_resolution=resolution, output_type="cv2")
return result

Preprocessor.add_supported_preprocessor(PreprocessorMobileSam())

0 comments on commit 5f5b51f

Please sign in to comment.