Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

all models can be downloaded with one shot #1358

Merged
merged 3 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 92 additions & 45 deletions deepface/commons/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from deepface.commons import folder_utils, package_utils
from deepface.commons.logger import Logger


tf_version = package_utils.get_tf_major_version()
if tf_version == 1:
from keras.models import Sequential
Expand All @@ -19,39 +20,7 @@

logger = Logger()

# pylint: disable=line-too-long
WEIGHTS = {
"facial_recognition": {
"VGG-Face": "https://github.com/serengil/deepface_models/releases/download/v1.0/vgg_face_weights.h5",
"Facenet": "https://github.com/serengil/deepface_models/releases/download/v1.0/facenet_weights.h5",
"Facenet512": "https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5",
"OpenFace": "https://github.com/serengil/deepface_models/releases/download/v1.0/openface_weights.h5",
"FbDeepFace": "https://github.com/swghosh/DeepFace/releases/download/weights-vggface2-2d-aligned/VGGFace2_DeepFace_weights_val-0.9034.h5.zip",
"ArcFace": "https://github.com/serengil/deepface_models/releases/download/v1.0/arcface_weights.h5",
"DeepID": "https://github.com/serengil/deepface_models/releases/download/v1.0/deepid_keras_weights.h5",
"SFace": "https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx",
"GhostFaceNet": "https://github.com/HamadYA/GhostFaceNets/releases/download/v1.2/GhostFaceNet_W1.3_S1_ArcFace.h5",
"Dlib": "http://dlib.net/files/dlib_face_recognition_resnet_model_v1.dat.bz2",
},
"demography": {
"Age": "https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5",
"Gender": "https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5",
"Emotion": "https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5",
"Race": "https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5",
},
"detection": {
"ssd_model": "https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt",
"ssd_weights": "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel",
"yolo": "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
"yunet": "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
"dlib": "http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2",
"centerface": "https://github.com/Star-Clouds/CenterFace/raw/master/models/onnx/centerface.onnx",
},
"spoofing": {
"MiniFASNetV2": "https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/2.7_80x80_MiniFASNetV2.pth",
"MiniFASNetV1SE": "https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/4_0_0_80x80_MiniFASNetV1SE.pth",
},
}
# pylint: disable=line-too-long, use-maxsplit-arg

ALLOWED_COMPRESS_TYPES = ["zip", "bz2"]

Expand Down Expand Up @@ -131,18 +100,96 @@ def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
return model


def retrieve_model_source(model_name: str, task: str) -> str:
def download_all_models_in_one_shot() -> None:
"""
Find the source url of a given model name
Args:
model_name (str): given model name
Returns:
weight_url (str): source url of the given model
Download all model weights in one shot
"""
if task not in ["facial_recognition", "detection", "demography", "spoofing"]:
raise ValueError(f"unimplemented task - {task}")

source_url = WEIGHTS.get(task, {}).get(model_name)
if source_url is None:
raise ValueError(f"Source url cannot be found for given model {task}-{model_name}")
return source_url
# weight urls as variables
from deepface.models.facial_recognition.VGGFace import WEIGHTS_URL as VGGFACE_WEIGHTS
from deepface.models.facial_recognition.Facenet import FACENET128_WEIGHTS, FACENET512_WEIGHTS
from deepface.models.facial_recognition.OpenFace import WEIGHTS_URL as OPENFACE_WEIGHTS
from deepface.models.facial_recognition.FbDeepFace import WEIGHTS_URL as FBDEEPFACE_WEIGHTS
from deepface.models.facial_recognition.ArcFace import WEIGHTS_URL as ARCFACE_WEIGHTS
from deepface.models.facial_recognition.DeepID import WEIGHTS_URL as DEEPID_WEIGHTS
from deepface.models.facial_recognition.SFace import WEIGHTS_URL as SFACE_WEIGHTS
from deepface.models.facial_recognition.GhostFaceNet import WEIGHTS_URL as GHOSTFACENET_WEIGHTS
from deepface.models.facial_recognition.Dlib import WEIGHT_URL as DLIB_FR_WEIGHTS
from deepface.models.demography.Age import WEIGHTS_URL as AGE_WEIGHTS
from deepface.models.demography.Gender import WEIGHTS_URL as GENDER_WEIGHTS
from deepface.models.demography.Race import WEIGHTS_URL as RACE_WEIGHTS
from deepface.models.demography.Emotion import WEIGHTS_URL as EMOTION_WEIGHTS
from deepface.models.spoofing.FasNet import (
FIRST_WEIGHTS_URL as FASNET_1ST_WEIGHTS,
SECOND_WEIGHTS_URL as FASNET_2ND_WEIGHTS,
)
from deepface.models.face_detection.Ssd import (
MODEL_URL as SSD_MODEL,
WEIGHTS_URL as SSD_WEIGHTS,
)
from deepface.models.face_detection.Yolo import (
WEIGHT_URL as YOLOV8_WEIGHTS,
WEIGHT_NAME as YOLOV8_WEIGHT_NAME,
)
from deepface.models.face_detection.YuNet import WEIGHTS_URL as YUNET_WEIGHTS
from deepface.models.face_detection.Dlib import WEIGHTS_URL as DLIB_FD_WEIGHTS
from deepface.models.face_detection.CenterFace import WEIGHTS_URL as CENTERFACE_WEIGHTS

WEIGHTS = [
# facial recognition
VGGFACE_WEIGHTS,
FACENET128_WEIGHTS,
FACENET512_WEIGHTS,
OPENFACE_WEIGHTS,
FBDEEPFACE_WEIGHTS,
ARCFACE_WEIGHTS,
DEEPID_WEIGHTS,
SFACE_WEIGHTS,
{
"filename": "ghostfacenet_v1.h5",
"url": GHOSTFACENET_WEIGHTS,
},
DLIB_FR_WEIGHTS,
# demography
AGE_WEIGHTS,
GENDER_WEIGHTS,
RACE_WEIGHTS,
EMOTION_WEIGHTS,
# spoofing
FASNET_1ST_WEIGHTS,
FASNET_2ND_WEIGHTS,
# face detection
SSD_MODEL,
SSD_WEIGHTS,
{
"filename": YOLOV8_WEIGHT_NAME,
"url": YOLOV8_WEIGHTS,
},
YUNET_WEIGHTS,
DLIB_FD_WEIGHTS,
CENTERFACE_WEIGHTS,
]

for i in WEIGHTS:
if isinstance(i, str):
url = i
filename = i.split("/")[-1]
compress_type = None
# if compressed file will be downloaded, get rid of its extension
if filename.endswith(tuple(ALLOWED_COMPRESS_TYPES)):
for ext in ALLOWED_COMPRESS_TYPES:
compress_type = ext
if filename.endswith(f".{ext}"):
filename = filename[: -(len(ext) + 1)]
break
elif isinstance(i, dict):
filename = i["filename"]
url = i["url"]
else:
raise ValueError("unimplemented scenario")
logger.info(
f"Downloading {url} to ~/.deepface/weights/{filename} with {compress_type} compression"
)
download_weights_if_necessary(
file_name=filename, source_url=url, compress_type=compress_type
)
11 changes: 7 additions & 4 deletions deepface/models/demography/Age.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

# ----------------------------------------

WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
)

# pylint: disable=too-few-public-methods
class ApparentAgeClient(Demography):
"""
Expand All @@ -41,7 +45,7 @@ def predict(self, img: np.ndarray) -> np.float64:


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5",
url=WEIGHTS_URL,
) -> Model:
"""
Construct age model, download its weights and load
Expand Down Expand Up @@ -70,12 +74,11 @@ def load_model(
file_name="age_model_weights.h5", source_url=url
)

age_model = weight_utils.load_model_weights(
model=age_model, weight_file=weight_file
)
age_model = weight_utils.load_model_weights(model=age_model, weight_file=weight_file)

return age_model


def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
"""
Find apparent age prediction from a given probas of ages
Expand Down
20 changes: 9 additions & 11 deletions deepface/models/demography/Emotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
from deepface.models.Demography import Demography
from deepface.commons.logger import Logger

logger = Logger()

# -------------------------------------------
# pylint: disable=line-too-long
# -------------------------------------------
# dependency configuration
tf_version = package_utils.get_tf_major_version()

Expand All @@ -28,12 +23,17 @@
Dense,
Dropout,
)
# -------------------------------------------

# Labels for the emotions that can be detected by the model.
labels = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"]

# pylint: disable=too-few-public-methods
logger = Logger()

# pylint: disable=line-too-long, disable=too-few-public-methods

WEIGHTS_URL = "https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5"


class EmotionClient(Demography):
"""
Emotion model class
Expand All @@ -56,7 +56,7 @@ def predict(self, img: np.ndarray) -> np.ndarray:


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5",
url=WEIGHTS_URL,
) -> Sequential:
"""
Consruct emotion model, download and load weights
Expand Down Expand Up @@ -96,8 +96,6 @@ def load_model(
file_name="facial_expression_model_weights.h5", source_url=url
)

model = weight_utils.load_model_weights(
model=model, weight_file=weight_file
)
model = weight_utils.load_model_weights(model=model, weight_file=weight_file)

return model
5 changes: 3 additions & 2 deletions deepface/models/demography/Gender.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Convolution2D, Flatten, Activation
# -------------------------------------

WEIGHTS_URL="https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5"

# Labels for the genders that can be detected by the model.
labels = ["Woman", "Man"]
Expand All @@ -43,7 +44,7 @@ def predict(self, img: np.ndarray) -> np.ndarray:


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5",
url=WEIGHTS_URL,
) -> Model:
"""
Construct gender model, download its weights and load
Expand Down
18 changes: 9 additions & 9 deletions deepface/models/demography/Race.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
from deepface.models.Demography import Demography
from deepface.commons.logger import Logger

logger = Logger()

# --------------------------
# pylint: disable=line-too-long
# --------------------------

# dependency configurations
tf_version = package_utils.get_tf_major_version()

Expand All @@ -21,10 +18,15 @@
else:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Convolution2D, Flatten, Activation
# --------------------------

WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5"
)
# Labels for the ethnic phenotypes that can be detected by the model.
labels = ["asian", "indian", "black", "white", "middle eastern", "latino hispanic"]

logger = Logger()

# pylint: disable=too-few-public-methods
class RaceClient(Demography):
"""
Expand All @@ -42,7 +44,7 @@ def predict(self, img: np.ndarray) -> np.ndarray:


def load_model(
url="https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5",
url=WEIGHTS_URL,
) -> Model:
"""
Construct race model, download its weights and load
Expand All @@ -69,8 +71,6 @@ def load_model(
file_name="race_model_single_batch.h5", source_url=url
)

race_model = weight_utils.load_model_weights(
model=race_model, weight_file=weight_file
)
race_model = weight_utils.load_model_weights(model=race_model, weight_file=weight_file)

return race_model
3 changes: 2 additions & 1 deletion deepface/models/face_detection/Dlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

logger = Logger()

WEIGHTS_URL="http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2"

class DlibClient(Detector):
def __init__(self):
Expand All @@ -34,7 +35,7 @@ def build_model(self) -> dict:
# check required file exists in the home/.deepface/weights folder
weight_file = weight_utils.download_weights_if_necessary(
file_name="shape_predictor_5_face_landmarks.dat",
source_url="http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2",
source_url=WEIGHTS_URL,
compress_type="bz2",
)

Expand Down
7 changes: 5 additions & 2 deletions deepface/models/face_detection/Ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

# pylint: disable=line-too-long, c-extension-no-member

MODEL_URL = "https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt"
WEIGHTS_URL = "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel"


class SsdClient(Detector):
def __init__(self):
Expand All @@ -31,13 +34,13 @@ def build_model(self) -> dict:
# model structure
output_model = weight_utils.download_weights_if_necessary(
file_name="deploy.prototxt",
source_url="https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt",
source_url=MODEL_URL,
)

# pre-trained weights
output_weights = weight_utils.download_weights_if_necessary(
file_name="res10_300x300_ssd_iter_140000.caffemodel",
source_url="https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel",
source_url=WEIGHTS_URL,
)

try:
Expand Down
4 changes: 2 additions & 2 deletions deepface/models/face_detection/Yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = Logger()

# Model's weights paths
PATH = ".deepface/weights/yolov8n-face.pt"
WEIGHT_NAME = "yolov8n-face.pt"

# Google Drive URL from repo (https://github.com/derronqi/yolov8-face) ~6MB
WEIGHT_URL = "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb"
Expand All @@ -39,7 +39,7 @@ def build_model(self) -> Any:
) from e

weight_file = weight_utils.download_weights_if_necessary(
file_name="yolov8n-face.pt", source_url=WEIGHT_URL
file_name=WEIGHT_NAME, source_url=WEIGHT_URL
)

# Return face_detector
Expand Down
5 changes: 4 additions & 1 deletion deepface/models/face_detection/YuNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

logger = Logger()

# pylint:disable=line-too-long
WEIGHTS_URL = "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx"


class YuNetClient(Detector):
def __init__(self):
Expand Down Expand Up @@ -41,7 +44,7 @@ def build_model(self) -> Any:
# pylint: disable=C0301
weight_file = weight_utils.download_weights_if_necessary(
file_name="face_detection_yunet_2023mar.onnx",
source_url="https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
source_url=WEIGHTS_URL,
)

try:
Expand Down
Loading