Skip to content

Commit

Permalink
Use Label Binarizer if multilabel
Browse files Browse the repository at this point in the history
  • Loading branch information
jpangas committed Jan 7, 2024
1 parent 0abfdcf commit 0bcedc3
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions bugbug/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import cross_validate, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import LabelBinarizer, LabelEncoder
from tabulate import tabulate
from xgboost import XGBModel

Expand Down Expand Up @@ -372,8 +372,14 @@ def train(self, importance_cutoff=0.15, limit=None):
# Extract features from the items.
X = self.extraction_pipeline.transform(X_gen)

# Calculate labels.
y = np.array(y)

is_multilabel = isinstance(y[0], np.ndarray)
is_binary = len(self.class_names) == 2

# Calculate labels.
if is_multilabel:
self.le = LabelBinarizer()
self.le.fit(y)

if limit:
Expand All @@ -382,9 +388,6 @@ def train(self, importance_cutoff=0.15, limit=None):

logger.info(f"X: {X.shape}, y: {y.shape}")

is_multilabel = isinstance(y[0], np.ndarray)
is_binary = len(self.class_names) == 2

# Split dataset in training and test.
X_train, X_test, y_train, y_test = self.train_test_split(X, y)

Expand Down

0 comments on commit 0bcedc3

Please sign in to comment.