From d1ffb2fcf3853cdbd78d371c90a9cd4c8b18cba0 Mon Sep 17 00:00:00 2001 From: Ian Randman <36488576+ianrandman@users.noreply.github.com> Date: Mon, 1 Jul 2024 05:25:37 -0400 Subject: [PATCH] Simplify zero-shot topic modeling (#2060) --- bertopic/_bertopic.py | 417 +++++++++++++++++++---------------------- tests/conftest.py | 2 +- tests/test_bertopic.py | 11 ++ 3 files changed, 206 insertions(+), 224 deletions(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 4a94d139..5682f40e 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -274,7 +274,7 @@ def __init__( self.topic_mapper_ = None self.topic_representations_ = None self.topic_embeddings_ = None - self.topic_labels_ = None + self._topic_id_to_zeroshot_topic_idx = {} self.custom_labels_ = None self.c_tf_idf_ = None self.representative_images_ = None @@ -282,7 +282,6 @@ def __init__( self.topic_aspects_ = {} # Private attributes for internal tracking purposes - self._outliers = 1 self._merged_topics = None if verbose: @@ -290,6 +289,39 @@ def __init__( else: logger.set_level("WARNING") + @property + def _outliers(self): + """Some algorithms have outlier labels (-1) that can be tricky to work + with if you are slicing data based on that labels. Therefore, we + track if there are outlier labels and act accordingly when slicing. + + Returns: + An integer indicating whether outliers are present in the topic model + """ + return 1 if -1 in self.topic_sizes_ else 0 + + @property + def topic_labels_(self): + """Map topic IDs to their labels. + A label is the topic ID, along with the first four words of the topic representation, joined using '_'. + Zeroshot topic labels come from self.zeroshot_topic_list rather than the calculated representation. + + Returns: + topic_labels: a dict mapping a topic ID (int) to its label (str) + """ + topic_labels = { + key: f"{key}_" + "_".join([word[0] for word in values[:4]]) + for key, values in self.topic_representations_.items() + } + if self._is_zeroshot(): + # Need to correct labels from zero-shot topics + topic_id_to_zeroshot_label = { + topic_id: self.zeroshot_topic_list[zeroshot_topic_idx] + for topic_id, zeroshot_topic_idx in self._topic_id_to_zeroshot_topic_idx.items() + } + topic_labels.update(topic_id_to_zeroshot_label) + return topic_labels + def fit( self, documents: List[str], @@ -422,23 +454,31 @@ def fit_transform( if self.seed_topic_list is not None and self.embedding_model is not None: y, embeddings = self._guided_topic_modeling(embeddings) + # Reduce dimensionality and fit UMAP model + umap_embeddings = self._reduce_dimensionality(embeddings, y) + # Zero-shot Topic Modeling if self._is_zeroshot(): documents, embeddings, assigned_documents, assigned_embeddings = ( self._zeroshot_topic_modeling(documents, embeddings) ) - if documents is None: - return self._combine_zeroshot_topics( - documents, assigned_documents, assigned_embeddings - ) - - # Reduce dimensionality - umap_embeddings = self._reduce_dimensionality(embeddings, y) + # Filter UMAP embeddings to only non-assigned embeddings to be used for clustering + umap_embeddings = self.umap_model.transform(embeddings) - # Cluster reduced embeddings - documents, probabilities = self._cluster_embeddings( - umap_embeddings, documents, y=y - ) + if len(documents) > 0: # No zero-shot topics matched + # Cluster reduced embeddings + documents, probabilities = self._cluster_embeddings( + umap_embeddings, documents, y=y + ) + if self._is_zeroshot() and len(assigned_documents) > 0: + documents, embeddings = 
self._combine_zeroshot_topics( + documents, embeddings, assigned_documents, assigned_embeddings + ) + else: + # All documents matches zero-shot topics + documents = assigned_documents + embeddings = assigned_embeddings + topics_before_reduction = self.topics_ # Sort and Map Topic IDs by their frequency if not self.nr_topics: @@ -469,18 +509,28 @@ def fit_transform( # Save the top 3 most representative documents per topic self._save_representative_docs(documents) + # In the case of zero-shot topics, probability will come from cosine similarity, + # and the HDBSCAN model will be removed + if self._is_zeroshot() and len(assigned_documents) > 0: + self.hdbscan_model = BaseCluster() + sim_matrix = cosine_similarity(embeddings, np.array(self.topic_embeddings_)) + + if self.calculate_probabilities: + probabilities = sim_matrix + else: + # Use `topics_before_reduction` because `self.topics_` may have already been updated from + # reducing topics, and the original probabilities are needed for `self._map_probabilities()` + probabilities = sim_matrix[ + np.arange(len(documents)), + np.array(topics_before_reduction) + self._outliers, + ] + # Resulting output self.probabilities_ = self._map_probabilities( probabilities, original_topics=True ) predictions = documents.Topic.to_list() - # Combine Zero-shot with outliers - if self._is_zeroshot() and len(documents) != len(doc_ids): - predictions = self._combine_zeroshot_topics( - documents, assigned_documents, assigned_embeddings - ) - return predictions, self.probabilities_ def transform( @@ -740,10 +790,6 @@ def partial_fit( updated_words, documents, self.c_tf_idf_, calculate_aspects=False ) self._create_topic_vectors() - self.topic_labels_ = { - key: f"{key}_" + "_".join([word[0] for word in values[:4]]) - for key, values in self.topic_representations_.items() - } # Update topic sizes if len(missing_topics) > 0: @@ -1598,15 +1644,17 @@ def update_topics( "c-TF-IDF embeddings instead of centroid embeddings." ) - self._outliers = 1 if -1 in set(topics) else 0 - - # Extract words documents = pd.DataFrame( {"Document": docs, "Topic": topics, "ID": range(len(docs)), "Image": images} ) documents_per_topic = documents.groupby(["Topic"], as_index=False).agg( {"Document": " ".join} ) + + # Update topic sizes and assignments + self._update_topic_size(documents) + + # Extract words and update topic labels self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic) self.topic_representations_ = self._extract_words_per_topic(words, documents) @@ -1625,13 +1673,6 @@ def update_topics( else: self._create_topic_vectors() - # Update topic labels - self.topic_labels_ = { - key: f"{key}_" + "_".join([word[0] for word in values[:4]]) - for key, values in self.topic_representations_.items() - } - self._update_topic_size(documents) - def get_topics(self, full: bool = False) -> Mapping[str, Tuple[str, float]]: """Return topics with top n words and their c-TF-IDF score. 
@@ -2251,11 +2292,11 @@ def merge_topics( for key, val in sorted(mapping.items()): mappings[val].append(key) mappings = { - topic_from: { - "topics_to": topics_to, - "topic_sizes": [self.topic_sizes_[topic] for topic in topics_to], + topic_to: { + "topics_from": topics_from, + "topic_sizes": [self.topic_sizes_[topic] for topic in topics_from], } - for topic_from, topics_to in mappings.items() + for topic_to, topics_from in mappings.items() } # Update topics @@ -2423,6 +2464,9 @@ def reduce_outliers( new_topics = topic_model.reduce_outliers(docs, topics, probabilities=probs, strategy="probabilities") ``` """ + if not self._outliers: + raise ValueError("No outliers to reduce.") + if images is not None: strategy = "embeddings" @@ -3959,11 +4003,6 @@ def _cluster_embeddings( documents["Topic"] = labels self._update_topic_size(documents) - # Some algorithms have outlier labels (-1) that can be tricky to work - # with if you are slicing data based on that labels. Therefore, we - # track if there are outlier labels and act accordingly when slicing. - self._outliers = 1 if -1 in set(labels) else 0 - # Extract probabilities probabilities = None if hasattr(self.hdbscan_model, "probabilities_"): @@ -4024,207 +4063,101 @@ def _zeroshot_topic_modeling( assigned_documents["ID"] = range(len(assigned_documents)) assigned_embeddings = embeddings[assigned_ids] + # Check that if a number of topics was specified, it exceeds the number of zeroshot topics matched + num_zeroshot_topics = len(assigned_documents["Topic"].unique()) + if self.nr_topics and not self.nr_topics > num_zeroshot_topics: + raise ValueError( + f"The set nr_topics ({self.nr_topics}) must exceed the number of matched zero-shot topics " + f"({num_zeroshot_topics}). Consider raising nr_topics or raising the " + f"zeroshot_min_similarity ({self.zeroshot_min_similarity})." + ) + # Select non-assigned topics to be clustered documents = documents.iloc[non_assigned_ids] documents["Old_ID"] = documents["ID"].copy() documents["ID"] = range(len(documents)) embeddings = embeddings[non_assigned_ids] - # If only matches were found - if len(non_assigned_ids) == 0: - return None, None, assigned_documents, assigned_embeddings logger.info("Zeroshot Step 1 - Completed \u2713") return documents, embeddings, assigned_documents, assigned_embeddings def _is_zeroshot(self): """Check whether zero-shot topic modeling is possible. - * There should be a cluster model used * Embedding model is necessary to convert zero-shot topics to embeddings * Zero-shot topics should be defined """ - if ( - self.zeroshot_topic_list is not None - and self.embedding_model is not None - and type(self.hdbscan_model) != BaseCluster - ): + if self.zeroshot_topic_list is not None and self.embedding_model is not None: return True return False def _combine_zeroshot_topics( self, documents: pd.DataFrame, - assigned_documents: pd.DataFrame, embeddings: np.ndarray, - ) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]: + assigned_documents: pd.DataFrame, + assigned_embeddings: np.ndarray, + ) -> Tuple[pd.DataFrame, np.ndarray]: """Combine the zero-shot topics with the clustered topics. 
- There are three cases considered: - * Only zero-shot topics were found which will only return the zero-shot topic model - * Only clustered topics were found which will only return the clustered topic model - * Both zero-shot and clustered topics were found which will return a merged model - * This merged model is created using the `merge_models` function which will ignore - the underlying UMAP and HDBSCAN models + The zero-shot topics will be inserted between the outlier topic (that may or may not exist) and the rest of the + topics from clustering. The rest of the topics from clustering will be given new IDs to correspond to topics + after zero-shot topics. + + Documents and embeddings used in zero-shot topic modeling and clustering and re-merged. Arguments: - documents: Dataframe with documents and their corresponding IDs - assigned_documents: Dataframe with documents and their corresponding IDs + documents: DataFrame with clustered documents and their corresponding IDs + embeddings: The document embeddings for clustered documents + assigned_documents: DataFrame with documents and their corresponding IDs that were assigned to a zero-shot topic - embeddings: The document embeddings + assigned_embeddings: The document embeddings for documents that were assigned to a zero-shot topic Returns: - topics: The topics for each document - probabilities: The probabilities for each document + documents: DataFrame with all the original documents with their topic assignments + embeddings: np.ndarray of embeddings aligned with the documents """ logger.info( - "Zeroshot Step 2 - Clustering documents that were not found in the zero-shot model..." - ) - - # Fit BERTopic without actually performing any clustering - docs = assigned_documents.Document.tolist() - y = assigned_documents.Topic.tolist() - empty_dimensionality_model = BaseDimensionalityReduction() - empty_cluster_model = BaseCluster() - zeroshot_model = BERTopic( - n_gram_range=self.n_gram_range, - low_memory=self.low_memory, - calculate_probabilities=self.calculate_probabilities, - embedding_model=self.embedding_model, - umap_model=empty_dimensionality_model, - hdbscan_model=empty_cluster_model, - vectorizer_model=self.vectorizer_model, - ctfidf_model=self.ctfidf_model, - representation_model=self.representation_model, - verbose=self.verbose, - ).fit(docs, embeddings=embeddings, y=y) - logger.info("Zeroshot Step 2 - Completed \u2713") - logger.info( - "Zeroshot Step 3 - Combining clustered topics with the zeroshot model" + "Zeroshot Step 2 - Combining topics from zero-shot topic modeling with topics from clustering..." 
) - - # Update model - self.umap_model = BaseDimensionalityReduction() - self.hdbscan_model = BaseCluster() - - # Update topic label - assigned_topics = assigned_documents.groupby("Topic").first().reset_index() - indices, topics = assigned_topics.ID.values, assigned_topics.Topic.values - labels = [ - zeroshot_model.topic_labels_[zeroshot_model.topics_[index]] - for index in indices - ] - labels = { - label: self.zeroshot_topic_list[topic] - for label, topic in zip(labels, topics) + # Combine Zero-shot topics with topics from clustering + zeroshot_topic_idx_to_topic_id = { + zeroshot_topic_id: new_topic_id + for new_topic_id, zeroshot_topic_id in enumerate( + set(assigned_documents.Topic) + ) } - - # If only zero-shot matches were found and clustering was not performed - if documents is None: - for topic in range(len(set(y))): - if zeroshot_model.topic_labels_.get(topic): - if labels.get(zeroshot_model.topic_labels_[topic]): - zeroshot_model.topic_labels_[topic] = labels[ - zeroshot_model.topic_labels_[topic] - ] - self.__dict__.clear() - self.__dict__.update(zeroshot_model.__dict__) - return self.topics_, self.probabilities_ - - # Merge the two topic models - merged_model = BERTopic.merge_models([zeroshot_model, self], min_similarity=1) - - # Update topic labels and representative docs of the zero-shot model - for topic in range(len(set(y))): - if merged_model.topic_labels_.get(topic): - if labels.get(merged_model.topic_labels_[topic]): - label = labels[merged_model.topic_labels_[topic]] - merged_model.topic_labels_[topic] = label - merged_model.representative_docs_[topic] = ( - zeroshot_model.representative_docs_[topic] - ) - - # Add representative docs of the clustered model - for topic in set(self.topics_): - merged_model.representative_docs_[topic + self._outliers + len(set(y))] = ( - self.representative_docs_[topic] + self._topic_id_to_zeroshot_topic_idx = { + new_topic_id: zeroshot_topic_id + for new_topic_id, zeroshot_topic_id in enumerate( + set(assigned_documents.Topic) ) + } + assigned_documents.Topic = assigned_documents.Topic.map( + zeroshot_topic_idx_to_topic_id + ) + num_zeroshot_topics = len(zeroshot_topic_idx_to_topic_id) - if self._outliers and merged_model.topic_sizes_.get(-1): - merged_model.topic_sizes_[len(set(y))] = merged_model.topic_sizes_[-1] - del merged_model.topic_sizes_[-1] - - # Update topic assignment by finding the documents with the - # correct updated topics - zeroshot_indices = list(assigned_documents.Old_ID.values) - zeroshot_topics = [ - self.zeroshot_topic_list[topic] for topic in assigned_documents.Topic.values - ] - - cluster_indices = list(documents.Old_ID.values) - cluster_names = list(merged_model.topic_labels_.values())[len(set(y)) :] - if self._outliers: - cluster_topics = [ - cluster_names[topic] if topic != -1 else "Outliers" - for topic in documents.Topic.values - ] - else: - cluster_topics = [cluster_names[topic] for topic in documents.Topic.values] - - df = pd.DataFrame( - { - "Indices": zeroshot_indices + cluster_indices, - "Label": zeroshot_topics + cluster_topics, - } - ).sort_values("Indices") - reverse_topic_labels = dict( - (v, k) for k, v in merged_model.topic_labels_.items() + # Insert zeroshot topics between outlier cluster and other clusters + documents.Topic = documents.Topic.apply( + lambda topic_id: topic_id + num_zeroshot_topics + if topic_id != -1 + else topic_id ) - if self._outliers: - reverse_topic_labels["Outliers"] = -1 - df.Label = df.Label.map(reverse_topic_labels) - merged_model.topics_ = df.Label.astype(int).tolist() 
- - # Update the class internally - has_outliers = bool(self._outliers) - self.__dict__.clear() - self.__dict__.update(merged_model.__dict__) - logger.info("Zeroshot Step 3 - Completed \u2713") - - # Move -1 topic back to position 0 if it exists - if has_outliers: - nr_zeroshot_topics = len(set(y)) - - # Re-map the topics such that the -1 topic is at position 0 - new_mappings = {} - for topic in self.topics_: - if topic < nr_zeroshot_topics: - new_mappings[topic] = topic - elif topic == nr_zeroshot_topics: - new_mappings[topic] = -1 - else: - new_mappings[topic] = topic - 1 - # Re-map the topics including all representations (labels, sizes, embeddings, etc.) - self.topics_ = [new_mappings[topic] for topic in self.topics_] - self.topic_representations_ = { - new_mappings[topic]: repr - for topic, repr in self.topic_representations_.items() - } - self.topic_labels_ = { - new_mappings[topic]: label - for topic, label in self.topic_labels_.items() - } - self.topic_sizes_ = collections.Counter(self.topics_) - self.topic_embeddings_ = np.vstack( - [ - self.topic_embeddings_[nr_zeroshot_topics], - self.topic_embeddings_[:nr_zeroshot_topics], - self.topic_embeddings_[nr_zeroshot_topics + 1 :], - ] - ) - self._outliers = 1 + # Combine the clustered documents/embeddings with assigned documents/embeddings in the original order + documents = pd.concat([documents, assigned_documents]) + embeddings = np.vstack([embeddings, assigned_embeddings]) + sorted_indices = documents.Old_ID.argsort() + documents = documents.iloc[sorted_indices] + embeddings = embeddings[sorted_indices] + + # Update topic sizes and topic mapper + self._update_topic_size(documents) + self.topic_mapper_ = TopicMapper(self.topics_) - return self.topics_ + logger.info("Zeroshot Step 2 - Completed \u2713") + return documents, embeddings def _guided_topic_modeling( self, embeddings: np.ndarray @@ -4304,10 +4237,6 @@ def _extract_topics( self._create_topic_vectors( documents=documents, embeddings=embeddings, mappings=mappings ) - self.topic_labels_ = { - key: f"{key}_" + "_".join([word[0] for word in values[:4]]) - for key, values in self.topic_representations_.items() - } if verbose: logger.info("Representation - Completed \u2713") @@ -4460,15 +4389,15 @@ def _create_topic_vectors( # Topic embeddings when merging topics elif self.topic_embeddings_ is not None and mappings is not None: topic_embeddings_dict = {} - for topic_from, topics_to in mappings.items(): - topic_ids = topics_to["topics_to"] - topic_sizes = topics_to["topic_sizes"] + for topic_to, topics_from in mappings.items(): + topic_ids = topics_from["topics_from"] + topic_sizes = topics_from["topic_sizes"] if topic_ids: embds = np.array(self.topic_embeddings_)[ np.array(topic_ids) + self._outliers ] topic_embedding = np.average(embds, axis=0, weights=topic_sizes) - topic_embeddings_dict[topic_from] = topic_embedding + topic_embeddings_dict[topic_to] = topic_embedding # Re-order topic embeddings topics_to_map = { @@ -4787,15 +4716,15 @@ def _reduce_to_n_topics( mapped_topics = { from_topic: to_topic for from_topic, to_topic in zip(topics, new_topics) } - mappings = defaultdict(list) + basic_mappings = defaultdict(list) for key, val in sorted(mapped_topics.items()): - mappings[val].append(key) + basic_mappings[val].append(key) mappings = { - topic_from: { - "topics_to": topics_to, - "topic_sizes": [self.topic_sizes_[topic] for topic in topics_to], + topic_to: { + "topics_from": topics_from, + "topic_sizes": [self.topic_sizes_[topic] for topic in topics_from], } - for topic_from, 
topics_to in mappings.items() + for topic_to, topics_from in basic_mappings.items() } # Map topics @@ -4806,6 +4735,52 @@ def _reduce_to_n_topics( # Update representations documents = self._sort_mappings_by_frequency(documents) self._extract_topics(documents, mappings=mappings) + + # When zero-shot topic(s) are present in the topics to merge, + # determine whether to take one of the zero-shot topic labels + # or use a calculated representation. + if self._is_zeroshot(): + new_topic_id_to_zeroshot_topic_idx = {} + topics_to_map = { + topic_mapping[0]: topic_mapping[1] + for topic_mapping in np.array(self.topic_mapper_.mappings_)[:, -2:] + } + + for topic_to, topics_from in basic_mappings.items(): + # When extracting topics, the reduced topics were reordered. + # Must get the updated topic_to. + topic_to = topics_to_map[topic_to] + + # which of the original topics are zero-shot + zeroshot_topic_ids = [ + topic_id + for topic_id in topics_from + if topic_id in self._topic_id_to_zeroshot_topic_idx + ] + if len(zeroshot_topic_ids) == 0: + continue + + # If any of the original topics are zero-shot, take the best fitting zero-shot label + # if the cosine similarity with the new topic exceeds the zero-shot threshold + zeroshot_labels = [ + self.zeroshot_topic_list[ + self._topic_id_to_zeroshot_topic_idx[topic_id] + ] + for topic_id in zeroshot_topic_ids + ] + zeroshot_embeddings = self._extract_embeddings(zeroshot_labels) + cosine_similarities = cosine_similarity( + zeroshot_embeddings, [self.topic_embeddings_[topic_to]] + ).flatten() + best_zeroshot_topic_idx = np.argmax(cosine_similarities) + best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx] + if best_cosine_similarity >= self.zeroshot_min_similarity: + new_topic_id_to_zeroshot_topic_idx[topic_to] = zeroshot_topic_ids[ + best_zeroshot_topic_idx + ] + + self._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx + self._update_topic_size(documents) return documents @@ -5205,11 +5180,7 @@ def _create_model_from_files( topic_model.topic_sizes_ = { int(key): val for key, val in topics["topic_sizes"].items() } - topic_model.topic_labels_ = { - int(key): val for key, val in topics["topic_labels"].items() - } topic_model.custom_labels_ = topics["custom_labels"] - topic_model._outliers = topics["_outliers"] if topics.get("topic_aspects"): topic_aspects = {} diff --git a/tests/conftest.py b/tests/conftest.py index 610c8690..95bcf738 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,7 +64,7 @@ def zeroshot_topic_model(documents, document_embeddings, embedding_model): embedding_model=embedding_model, calculate_probabilities=True, zeroshot_topic_list=zeroshot_topic_list, - zeroshot_min_similarity=0.5, + zeroshot_min_similarity=0.3, ) model.umap_model.random_state = 42 model.hdbscan_model.min_cluster_size = 2 diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index a899680b..73614e1b 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -72,6 +72,17 @@ def test_full_model(model, documents, request): assert len(topics_test) == 2 + # Test zero-shot topic modeling + if topic_model._is_zeroshot(): + if topic_model._outliers: + assert set(topic_model.topic_labels_.keys()) == set( + range(-1, len(topic_model.topic_labels_) - 1) + ) + else: + assert set(topic_model.topic_labels_.keys()) == set( + range(len(topic_model.topic_labels_)) + ) + # Test topics over time timestamps = [i % 10 for i in range(len(documents))] topics_over_time = topic_model.topics_over_time(documents, timestamps)
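
For reviewers trying the change locally, here is a minimal usage sketch of the simplified zero-shot flow this patch produces. It is an illustration under stated assumptions rather than part of the patch: the corpus, embedding model name, and candidate topic list are placeholders, and the 0.3 threshold simply mirrors the updated test fixture.

```python
# Minimal sketch of the simplified zero-shot workflow (assumed setup, not from the patch).
from bertopic import BERTopic
from sklearn.datasets import fetch_20newsgroups

# Small illustrative corpus; any list of strings works.
docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes")).data[:1000]

# Candidate topics supplied up front. Documents whose embedding is at least
# `zeroshot_min_similarity` similar to one of these labels are assigned to that
# topic directly; the remaining documents go through UMAP + clustering as usual.
zeroshot_topic_list = ["computer hardware", "religion", "space exploration"]

topic_model = BERTopic(
    embedding_model="all-MiniLM-L6-v2",  # assumed sentence-transformers model
    zeroshot_topic_list=zeroshot_topic_list,
    zeroshot_min_similarity=0.3,         # value taken from the updated test fixture
    calculate_probabilities=True,
)
topics, probs = topic_model.fit_transform(docs)

# `topic_labels_` is now a derived property: topics that came from the zero-shot
# list keep their supplied labels, clustered topics get the usual
# "<id>_word1_word2_word3_word4" labels, and the IDs are contiguous starting at
# -1 (or at 0 when there are no outliers).
print(topic_model.topic_labels_)
```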
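Two behavioral details of the new flow are worth noting. Zero-shot topics always receive the lowest non-outlier IDs (0 through k-1 for k matched zero-shot topics) while clustered topics are shifted up by k, so the outlier topic, if present, keeps ID -1; `topic_labels_` and `_topic_id_to_zeroshot_topic_idx` stay consistent with that layout, including after topic reduction. In addition, `fit_transform` now raises a `ValueError` when `nr_topics` is set to a value that does not exceed the number of matched zero-shot topics, and `reduce_outliers` raises a `ValueError` when the model has no outliers to reduce.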