From 5f5b51fb6259e4248bea9f3fc6da6e115eb08aad Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Mon, 15 Jul 2024 15:11:15 +0800 Subject: [PATCH] update --- annotator/mobile_sam/__init__.py | 16 +++++++++------- scripts/preprocessor/mobile_sam.py | 6 +++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/annotator/mobile_sam/__init__.py b/annotator/mobile_sam/__init__.py index 43b29436e..57b4124f8 100644 --- a/annotator/mobile_sam/__init__.py +++ b/annotator/mobile_sam/__init__.py @@ -16,8 +16,10 @@ class SamDetector_Aux(SamDetector): model_dir = os.path.join(models_path, "mobile_sam") - def __init__(self, mask_generator: SamAutomaticMaskGenerator): + def __init__(self, mask_generator: SamAutomaticMaskGenerator, sam): super().__init__(mask_generator) + self.device = devices.device + self.model = sam.to(self.device).eval() @classmethod def from_pretrained(cls): @@ -35,13 +37,13 @@ def from_pretrained(cls): sam = sam_model_registry["vit_t"](checkpoint=model_path) - cls.device = devices.device - cls.model = SamDetector_Aux().to(cls.device).eval() + cls.model = sam.to(devices.device).eval() - mask_generator = SamAutomaticMaskGenerator(sam) + mask_generator = SamAutomaticMaskGenerator(cls.model) - return cls(mask_generator) + return cls(mask_generator, sam) - def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs) -> np.ndarray: + def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, image_resolution=512, output_type="cv2", **kwargs) -> np.ndarray: self.model.to(self.device) - super().__call__(image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs) \ No newline at end of file + image = super().__call__(input_image=input_image, detect_resolution=detect_resolution, image_resolution=image_resolution, output_type=output_type, **kwargs) + return np.array(image).astype(np.uint8) \ No newline at end of file diff --git a/scripts/preprocessor/mobile_sam.py b/scripts/preprocessor/mobile_sam.py index cce7065f2..5f7cd7849 100644 --- a/scripts/preprocessor/mobile_sam.py +++ b/scripts/preprocessor/mobile_sam.py @@ -17,11 +17,11 @@ def __call__( slider_3=None, **kwargs ): - img, remove_pad = resize_image_with_pad(input_image, resolution) + #img, remove_pad = resize_image_with_pad(input_image, resolution) if self.model is None: self.model = SamDetector_Aux.from_pretrained() - result = self.model(img, detect_resolution=resolution, image_resolution=resolution) - return remove_pad(result) + result = self.model(input_image, detect_resolution=resolution, image_resolution=resolution, output_type="cv2") + return result Preprocessor.add_supported_preprocessor(PreprocessorMobileSam()) \ No newline at end of file