From 0bcedc35035853c02403a1786c6d08814f0683fa Mon Sep 17 00:00:00 2001 From: John Pangas Date: Mon, 8 Jan 2024 00:14:45 +0300 Subject: [PATCH] Use Label Binarizer if multilabel --- bugbug/model.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/bugbug/model.py b/bugbug/model.py index ed73b955e4..fad13008f7 100644 --- a/bugbug/model.py +++ b/bugbug/model.py @@ -23,7 +23,7 @@ from sklearn.metrics import precision_recall_fscore_support from sklearn.model_selection import cross_validate, train_test_split from sklearn.pipeline import Pipeline -from sklearn.preprocessing import LabelEncoder +from sklearn.preprocessing import LabelBinarizer, LabelEncoder from tabulate import tabulate from xgboost import XGBModel @@ -372,8 +372,14 @@ def train(self, importance_cutoff=0.15, limit=None): # Extract features from the items. X = self.extraction_pipeline.transform(X_gen) - # Calculate labels. y = np.array(y) + + is_multilabel = isinstance(y[0], np.ndarray) + is_binary = len(self.class_names) == 2 + + # Calculate labels. + if is_multilabel: + self.le = LabelBinarizer() self.le.fit(y) if limit: @@ -382,9 +388,6 @@ def train(self, importance_cutoff=0.15, limit=None): logger.info(f"X: {X.shape}, y: {y.shape}") - is_multilabel = isinstance(y[0], np.ndarray) - is_binary = len(self.class_names) == 2 - # Split dataset in training and test. X_train, X_test, y_train, y_test = self.train_test_split(X, y)