Skip to content

Commit

Permalink
apply Black 2024 style in fbcode (15/17)
Browse files Browse the repository at this point in the history
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: zertosh

Differential Revision: D54470857

fbshipit-source-id: 6e68f9cec670e75777c696a5a47731cedab782c6
  • Loading branch information
amyreese authored and facebook-github-bot committed Mar 4, 2024
1 parent 658294c commit 56d17d3
Show file tree
Hide file tree
Showing 28 changed files with 159 additions and 126 deletions.
6 changes: 2 additions & 4 deletions kats/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,11 @@ def load_data(file_name: str, reset_columns: bool = False) -> pd.DataFrame:


@overload
def load_air_passengers(return_ts: Literal[True]) -> TimeSeriesData:
...
def load_air_passengers(return_ts: Literal[True]) -> TimeSeriesData: ...


@overload
def load_air_passengers(return_ts: Literal[False] = ...) -> pd.DataFrame:
...
def load_air_passengers(return_ts: Literal[False] = ...) -> pd.DataFrame: ...


def load_air_passengers(return_ts: bool = True) -> Union[pd.DataFrame, TimeSeriesData]:
Expand Down
10 changes: 6 additions & 4 deletions kats/detectors/cusum_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,10 +1224,12 @@ def detector_(self, **kwargs: Any) -> List[List[CUSUMChangePoint]]:
if not list(change_meta_["changepoint"]):
continue
change_meta = {
k: change_meta_[k][col_idx]
if isinstance(change_meta_[k], np.ndarray)
or isinstance(change_meta_[k], list)
else change_meta_[k]
k: (
change_meta_[k][col_idx]
if isinstance(change_meta_[k], np.ndarray)
or isinstance(change_meta_[k], list)
else change_meta_[k]
)
for k in change_meta_
}
change_meta["llr"] = llr = self._get_llr(
Expand Down
8 changes: 5 additions & 3 deletions kats/detectors/cusum_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,9 +1276,11 @@ def _fit(
)

cps = [
sorted(x, key=lambda x: x.start_time)[0]
if x and alert_set_on_mask[i]
else None
(
sorted(x, key=lambda x: x.start_time)[0]
if x and alert_set_on_mask[i]
else None
)
for i, x in enumerate(changepoints)
]

Expand Down
22 changes: 13 additions & 9 deletions kats/detectors/detector_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,12 @@ def stat_sig(self) -> Union[bool, ArrayLike]:
if self.num_series > 1:
return np.array(
[
False
if cast(np.ndarray, self.upper)[i] > 1.0
and cast(np.ndarray, self.lower)[i] < 1
else True
(
False
if cast(np.ndarray, self.upper)[i] > 1.0
and cast(np.ndarray, self.lower)[i] < 1
else True
)
for i in range(self.current.num_series)
]
)
Expand Down Expand Up @@ -649,11 +651,13 @@ def get_last_n(self, N: int) -> AnomalyResponse:

return AnomalyResponse(
scores=self.scores[-N:],
confidence_band=None
if cb is None
else ConfidenceBand(
upper=cb.upper[-N:],
lower=cb.lower[-N:],
confidence_band=(
None
if cb is None
else ConfidenceBand(
upper=cb.upper[-N:],
lower=cb.lower[-N:],
)
),
predicted_ts=None if pts is None else pts[-N:],
anomaly_magnitude_ts=self.anomaly_magnitude_ts[-N:],
Expand Down
10 changes: 6 additions & 4 deletions kats/detectors/distribution_distance_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,12 @@ def fit_predict(
window=str(self.window_size_sec) + "s",
closed="both",
).agg(
lambda rows: rows[0]
if (rows.index[-1] - rows.index[0]).total_seconds()
> 0.9 * self.window_size_sec # tolerance
else np.nan
lambda rows: (
rows[0]
if (rows.index[-1] - rows.index[0]).total_seconds()
> 0.9 * self.window_size_sec # tolerance
else np.nan
)
)

# exclude the beginning part of NANs
Expand Down
8 changes: 5 additions & 3 deletions kats/detectors/meta_learning/synth_metadata_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ def get_metadata(self, algorithm_name: str) -> Dict[str, pd.DataFrame]:
.map(lambda kv: kv[a][0])
.map(
lambda kv: {
k: v
if k not in self.PARAMS_TO_SCALE_DOWN
else v / SynthMetadataReader.NUM_SECS_IN_DAY
k: (
v
if k not in self.PARAMS_TO_SCALE_DOWN
else v / SynthMetadataReader.NUM_SECS_IN_DAY
)
for k, v in kv.items()
}
)
Expand Down
2 changes: 1 addition & 1 deletion kats/detectors/outlier.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(

def __clean_ts__(
self,
original: Union[pd.Series, pd.DataFrame]
original: Union[pd.Series, pd.DataFrame],
# pyre-fixme[11]: Annotation `Timestamp` is not defined as a type.
) -> Tuple[List[int], List[float], List[pd.Timestamp]]:
"""
Expand Down
2 changes: 2 additions & 0 deletions kats/detectors/prophet_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import sys

NOT_SUPPRESS_PROPHET_FIT_LOGS_VAR_NAME = "NOT_SUPPRESS_PROPHET_FIT_LOGS"


# this is a bug in prophet which was discussed in open source thread
# issues was also suggested
# details https://github.com/facebook/prophet/issues/223#issuecomment-326455744
Expand Down
1 change: 0 additions & 1 deletion kats/detectors/stat_sig_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,6 @@ def predict(


class MultiStatSigDetectorModel(StatSigDetectorModel):

"""
MultiStatSigDetectorModel is a multivariate version of the StatSigDetector. It applies a univariate
t-test to each of the components of the multivariate time series to see if the means between the control
Expand Down
21 changes: 9 additions & 12 deletions kats/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# from numpy.typing import ArrayLike
ArrayLike = Union[np.ndarray, Sequence[float]]


# Type aliases
#
# Most metrics have the shape:
Expand All @@ -47,13 +48,13 @@ def __call__(
y_true: ArrayLike,
y_pred: ArrayLike,
sample_weight: Optional[ArrayLike] = ...,
) -> np.ndarray:
... # pragma: no cover
) -> np.ndarray: ... # pragma: no cover


class Metric(Protocol):
def __call__(self, y_true: ArrayLike, y_pred: ArrayLike) -> float:
... # pragma: no cover
def __call__(
self, y_true: ArrayLike, y_pred: ArrayLike
) -> float: ... # pragma: no cover


class WeightedMetric(Protocol):
Expand All @@ -62,8 +63,7 @@ def __call__(
y_true: ArrayLike,
y_pred: ArrayLike,
sample_weight: Optional[ArrayLike] = ...,
) -> float:
... # pragma: no cover
) -> float: ... # pragma: no cover


class MultiOutputMetric(Protocol):
Expand All @@ -73,8 +73,7 @@ def __call__(
y_pred: ArrayLike,
sample_weight: Optional[ArrayLike] = ...,
multioutput: Union[str, ArrayLike] = ...,
) -> float:
... # pragma: no cover
) -> float: ... # pragma: no cover


class ThresholdMetric(Protocol):
Expand All @@ -83,8 +82,7 @@ def __call__(
y_true: ArrayLike,
y_pred: ArrayLike,
threshold: float,
) -> float:
... # pragma: no cover
) -> float: ... # pragma: no cover


class MultiThresholdMetric(Protocol):
Expand All @@ -93,8 +91,7 @@ def __call__(
y_true: ArrayLike,
y_pred: ArrayLike,
threshold: ArrayLike,
) -> np.ndarray:
... # pragma: no cover
) -> np.ndarray: ... # pragma: no cover


KatsMetric = Union[
Expand Down
2 changes: 1 addition & 1 deletion kats/models/ensemble/kats_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def _fit_single(
data: TimeSeriesData,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
model_func: Callable,
model_param: Params
model_param: Params,
# pyre-fixme[24]: Generic type `Model` expects 1 type parameter.
) -> Model:
"""Private method to fit individual model
Expand Down
13 changes: 6 additions & 7 deletions kats/models/globalmodel/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,12 @@ def __init__(
if params.model_type == "rnn" and params.seasonality > 1:
init_seasonality = self._get_seasonality(train_x, params.seasonality)
# bound initial seasonalities
init_seasonality[
init_seasonality < params.init_seasonality[0]
] = params.init_seasonality[0]
init_seasonality[
init_seasonality > params.init_seasonality[1]
] = params.init_seasonality[1]
init_seasonality[init_seasonality < params.init_seasonality[0]] = (
params.init_seasonality[0]
)
init_seasonality[init_seasonality > params.init_seasonality[1]] = (
params.init_seasonality[1]
)
# pyre-fixme[4]: Attribute must be annotated.
self.init_seasonality = torch.tensor(init_seasonality, dtype=tdtype)
else:
Expand Down Expand Up @@ -451,7 +451,6 @@ def _get_array(
Optional[np.ndarray],
Optional[np.ndarray],
]:

"""
Helper function for transforming TS to arrays, including truncating/padding values,
Expand Down
22 changes: 12 additions & 10 deletions kats/models/globalmodel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,9 +545,7 @@ def _format_fcst(
for i, idx in enumerate(ids):

df = pd.DataFrame(
fcst[i].transpose()[
:steps,
],
fcst[i].transpose()[:steps,],
columns=cols,
)
df["time"] = pd.date_range(
Expand Down Expand Up @@ -651,12 +649,12 @@ def save_model(self, file_name: str) -> None:
info = {
"gmparam_string": self.params.to_string(),
"state_dict": self.rnn.state_dict() if self.rnn is not None else None,
"encoder_state_dict": self.encoder.state_dict()
if self.encoder is not None
else None,
"decoder_state_dict": self.decoder.state_dict()
if self.decoder is not None
else None,
"encoder_state_dict": (
self.encoder.state_dict() if self.encoder is not None else None
),
"decoder_state_dict": (
self.decoder.state_dict() if self.decoder is not None else None
),
}
with open(file_name, "wb") as f:
joblib.dump(info, f)
Expand Down Expand Up @@ -1099,7 +1097,11 @@ def _single_pass_s2s(
cur_step + 1
)

(x_t, anchor_level, x_lt,) = self._process_s2s(
(
x_t,
anchor_level,
x_lt,
) = self._process_s2s(
prev_idx, cur_idx, batch.x, x_lt, period, params.input_window
)

Expand Down
1 change: 0 additions & 1 deletion kats/models/globalmodel/stdmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ def _deseasonal(
def _predict_seasonality(
self, steps: int, tsd_model: Union[ProphetModel, np.ndarray]
) -> np.ndarray:

"""Predict the future seasonality.
Args:
Expand Down
18 changes: 11 additions & 7 deletions kats/models/globalmodel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_filters(isna_idx, seasonality) -> np.ndarray:
else:
i += 1
filters = np.array([True] * n)
for (i, j) in flips:
for i, j in flips:
filters[i:j] = False
return filters

Expand Down Expand Up @@ -188,19 +188,23 @@ def split(
split_data = [
(
{t: train_TSs[t] for t in keys[~index[i]]},
{t: valid_TSs[t] for t in keys[~index[i]]}
if valid_TSs is not None
else None,
(
{t: valid_TSs[t] for t in keys[~index[i]]}
if valid_TSs is not None
else None
),
)
for i in range(splits)
]
else:
split_data = [
(
{t: train_TSs[t] for t in keys[index[i]]},
{t: valid_TSs[t] for t in keys[index[i]]}
if valid_TSs is not None
else None,
(
{t: valid_TSs[t] for t in keys[index[i]]}
if valid_TSs is not None
else None
),
)
for i in range(splits)
]
Expand Down
2 changes: 1 addition & 1 deletion kats/models/metalearner/get_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _tune_single(
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
single_model: Callable,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
single_params: Callable
single_params: Callable,
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict` to avoid runtime subscripting errors.
) -> Tuple[Dict, float]:
Expand Down
4 changes: 1 addition & 3 deletions kats/models/metalearner/metalearner_hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,7 @@ def build_network(
print("Multi-task neural network structure:")
print(self.model)

def _prepare_data(
self, val_size: float
) -> Tuple[
def _prepare_data(self, val_size: float) -> Tuple[
torch.FloatTensor,
Optional[torch.LongTensor],
Optional[torch.FloatTensor],
Expand Down
6 changes: 3 additions & 3 deletions kats/models/ml_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,9 +934,9 @@ def _merge_past_and_future_reg(
num_rows_dat = norm_in_data[target_var].shape[0]
num_cols_dat = norm_in_data[target_var].shape[1]

full_mat[
tv_idx : (tv_idx + num_rows_dat), 0:num_cols_dat
] = norm_in_data[target_var]
full_mat[tv_idx : (tv_idx + num_rows_dat), 0:num_cols_dat] = (
norm_in_data[target_var]
)

full_mat[
tv_idx : (tv_idx + num_rows_dat),
Expand Down
Loading

0 comments on commit 56d17d3

Please sign in to comment.