Skip to content

Commit

Permalink
Merge pull request #87 from MilagrosMarin/train_bug
Browse files Browse the repository at this point in the history
Fix project path in the pose config file
  • Loading branch information
kabilar authored Aug 4, 2023
2 parents 4046d89 + 81eb296 commit 78ac980
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.

## [0.2.7] - 2023-08-04

+ Fix - Update the project path in the pose config file to train the model

## [0.2.6] - 2023-05-22

+ Add - DeepLabCut, NWB, and DANDI citations
Expand Down Expand Up @@ -68,6 +72,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
graciously provided by the Mathis Lab.
+ Add - Support for 2d single-animal models

[0.2.7]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.7
[0.2.6]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.6
[0.2.5]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.5
[0.2.4]: https://github.com/datajoint/element-deeplabcut/releases/tag/0.2.4
Expand Down
44 changes: 26 additions & 18 deletions element_deeplabcut/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,17 @@ class ModelTraining(dj.Computed):
# https://github.com/DeepLabCut/DeepLabCut/issues/70

def make(self, key):
from deeplabcut import train_network # isort:skip
from deeplabcut import train_network # isort:skip

try:
from deeplabcut.utils.auxiliaryfunctions import get_model_folder # isort:skip
from deeplabcut.utils.auxiliaryfunctions import (
get_model_folder,
edit_config,
) # isort:skip
except ImportError:
from deeplabcut.utils.auxiliaryfunctions import (
GetModelFolder as get_model_folder
) # isort:skip
GetModelFolder as get_model_folder,
) # isort:skip

"""Launch training for each train.TrainingTask training_id via `.populate()`."""
project_path, model_prefix = (TrainingTask & key).fetch1(
Expand Down Expand Up @@ -275,11 +279,26 @@ def make(self, key):
# Write dlc config file to base project folder
dlc_cfg_filepath = dlc_reader.save_yaml(project_path, dlc_config)

# ---- Update the project path in the DLC pose configuration (yaml) files ----
model_folder = get_model_folder(
trainFraction=dlc_config["train_fraction"],
shuffle=dlc_config["shuffle"],
cfg=dlc_config,
modelprefix=dlc_config["modelprefix"],
)
model_train_folder = project_path / model_folder / "train"

edit_config(
model_train_folder / "pose_cfg.yaml",
{"project_path": project_path.as_posix()},
)

# ---- Trigger DLC model training job ----
train_network_input_args = list(inspect.signature(train_network).parameters)
train_network_kwargs = {
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
for k, v in dlc_config.items() if k in train_network_input_args
k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
for k, v in dlc_config.items()
if k in train_network_input_args
}
for k in ["shuffle", "trainingsetindex", "maxiters"]:
train_network_kwargs[k] = int(train_network_kwargs[k])
Expand All @@ -289,18 +308,7 @@ def make(self, key):
except KeyboardInterrupt: # Instructions indicate to train until interrupt
print("DLC training stopped via Keyboard Interrupt")

snapshots = list(
(
project_path
/ get_model_folder(
trainFraction=dlc_config["train_fraction"],
shuffle=dlc_config["shuffle"],
cfg=dlc_config,
modelprefix=dlc_config["modelprefix"],
)
/ "train"
).glob("*index*")
)
snapshots = list(model_train_folder.glob("*index*"))
max_modified_time = 0
# DLC goes by snapshot magnitude when judging 'latest' for evaluation
# Here, we mean most recently generated
Expand Down
2 changes: 1 addition & 1 deletion element_deeplabcut/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
Package metadata
"""
__version__ = "0.2.6"
__version__ = "0.2.7"

0 comments on commit 78ac980

Please sign in to comment.