From 81eb29654b202bbb324908ccb4e68c9dc66a58d3 Mon Sep 17 00:00:00 2001 From: Milagros Marin Date: Fri, 4 Aug 2023 17:07:05 -0500 Subject: [PATCH] Fix project path --- CHANGELOG.md | 5 ++++ element_deeplabcut/train.py | 44 +++++++++++++++++++++-------------- element_deeplabcut/version.py | 2 +- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b264da9..a81eb73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,10 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. +## [0.2.7] - 2023-08-04 + ++ Fix - Update the project path in the pose config file to train the model + ## [0.2.6] - 2023-05-22 + Add - DeepLabCut, NWB, and DANDI citations @@ -68,6 +72,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and graciously provided by the Mathis Lab. + Add - Support for 2d single-animal models +[0.2.7]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.7 [0.2.6]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.6 [0.2.5]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.5 [0.2.4]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.4 diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py index ff85146..cb37a4c 100644 --- a/element_deeplabcut/train.py +++ b/element_deeplabcut/train.py @@ -241,13 +241,17 @@ class ModelTraining(dj.Computed): # https://github.com/DeepLabCut/DeepLabCut/issues/70 def make(self, key): - from deeplabcut import train_network # isort:skip + from deeplabcut import train_network # isort:skip + try: - from deeplabcut.utils.auxiliaryfunctions import get_model_folder # isort:skip + from deeplabcut.utils.auxiliaryfunctions import ( + get_model_folder, + edit_config, + ) # isort:skip except ImportError: from deeplabcut.utils.auxiliaryfunctions import ( - GetModelFolder as get_model_folder - ) # isort:skip + GetModelFolder as get_model_folder, + ) # isort:skip """Launch training for each train.TrainingTask training_id via `.populate()`.""" project_path, model_prefix = (TrainingTask & key).fetch1( @@ -275,11 +279,26 @@ def make(self, key): # Write dlc config file to base project folder dlc_cfg_filepath = dlc_reader.save_yaml(project_path, dlc_config) + # ---- Update the project path in the DLC pose configuration (yaml) files ---- + model_folder = get_model_folder( + trainFraction=dlc_config["train_fraction"], + shuffle=dlc_config["shuffle"], + cfg=dlc_config, + modelprefix=dlc_config["modelprefix"], + ) + model_train_folder = project_path / model_folder / "train" + + edit_config( + model_train_folder / "pose_cfg.yaml", + {"project_path": project_path.as_posix()}, + ) + # ---- Trigger DLC model training job ---- train_network_input_args = list(inspect.signature(train_network).parameters) train_network_kwargs = { - k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v - for k, v in dlc_config.items() if k in train_network_input_args + k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v + for k, v in dlc_config.items() + if k in train_network_input_args } for k in ["shuffle", "trainingsetindex", "maxiters"]: train_network_kwargs[k] = int(train_network_kwargs[k]) @@ -289,18 +308,7 @@ def make(self, key): except KeyboardInterrupt: # Instructions indicate to train until interrupt print("DLC training stopped via Keyboard Interrupt") - snapshots = list( - ( - project_path - / get_model_folder( - trainFraction=dlc_config["train_fraction"], - shuffle=dlc_config["shuffle"], - cfg=dlc_config, - modelprefix=dlc_config["modelprefix"], - ) - / "train" - ).glob("*index*") - ) + snapshots = list(model_train_folder.glob("*index*")) max_modified_time = 0 # DLC goes by snapshot magnitude when judging 'latest' for evaluation # Here, we mean most recently generated diff --git a/element_deeplabcut/version.py b/element_deeplabcut/version.py index 21ce3f8..617ac09 100644 --- a/element_deeplabcut/version.py +++ b/element_deeplabcut/version.py @@ -1,4 +1,4 @@ """ Package metadata """ -__version__ = "0.2.6" +__version__ = "0.2.7"