Skip to content

Commit

Permalink
Adding stopping strategy to model optimizer: based on time and target…
Browse files Browse the repository at this point in the history
… metric value

Summary: Adding stopping strategy to model optimizer: based on time and target metric value

Differential Revision: D49830632

fbshipit-source-id: b3d0e7f06b82337290fead8047c4b2dd8a2fabaf
  • Loading branch information
irumata authored and facebook-github-bot committed Oct 5, 2023
1 parent 5aa36cf commit a826d33
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions kats/utils/time_series_parameter_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,8 @@ class SearchMethodOptions:
outcome_constraints: Optional[List[str]] = None
multiprocessing: Union[bool, int] = False
seed: Optional[int] = None
time_limit: float = -1.0
target_metric_val: Optional[float] = None


class SearchMethodFactory(metaclass=Final):
Expand Down Expand Up @@ -1171,6 +1173,22 @@ def __call__(self, optimizer: Any) -> bool: # type: ignore
return self._tolerance_count > self._tolerance_window


class _LossTargetMetricCriterion:
def __init__(self, target_metric_val: float) -> None:
self.target_metric_val: float = target_metric_val

def __call__(self, optimizer: Any) -> bool: # type: ignore
best_param = optimizer.provide_recommendation()
if best_param is None or (
best_param.loss is None and best_param._losses is None
):
return False
best_last_losses = best_param.losses
if best_last_losses is None:
return False
return best_last_losses <= self.target_metric_val


def get_nevergrad_param_from_ax(
ax_params: List[Dict[str, Any]]
) -> ng.p.Instrumentation:
Expand Down Expand Up @@ -1257,6 +1275,19 @@ def __init__(
)
),
) # should get triggered
if self.options.time_limit > 0:
self.optimizer.register_callback(
"ask",
ng.callbacks.EarlyStopping.timer(self.options.time_limit),
) # should get triggered

if self.options.target_metric_val is not None:
self.optimizer.register_callback(
"ask",
ng.callbacks.EarlyStopping(
_LossTargetMetricCriterion(self.options.target_metric_val)
),
) # should get triggered

def generate_evaluate_new_parameter_values(
self,
Expand Down

0 comments on commit a826d33

Please sign in to comment.