From c3b2f70f09eb55fba303f781887528c707a4b803 Mon Sep 17 00:00:00 2001 From: nik Date: Thu, 31 Aug 2023 18:00:02 +0100 Subject: [PATCH] Add support for pd.DataFrame labeling --- label_studio_sdk/client.py | 2 +- label_studio_sdk/project.py | 65 ++++++++++++++++++++++++++++++++++++- label_studio_sdk/utils.py | 38 ++++++++++++++++++++++ 3 files changed, 103 insertions(+), 2 deletions(-) diff --git a/label_studio_sdk/client.py b/label_studio_sdk/client.py index 5a40f4c8..3e751a8f 100644 --- a/label_studio_sdk/client.py +++ b/label_studio_sdk/client.py @@ -408,7 +408,7 @@ def make_request(self, method, url, *args, **kwargs): except: content = response.text - logger.error( + logger.debug( f'\n--------------------------------------------\n' f'Request URL: {response.url}\n' f'Response status code: {response.status_code}\n' diff --git a/label_studio_sdk/project.py b/label_studio_sdk/project.py index 9cb10923..e3eb24c3 100644 --- a/label_studio_sdk/project.py +++ b/label_studio_sdk/project.py @@ -3,6 +3,7 @@ import os import json import logging +import pandas as pd import pathlib import time @@ -474,7 +475,7 @@ def get_from_id(cls, client, project_id) -> "Project": project.update_params() return project - def import_tasks(self, tasks, preannotated_from_fields: List = None): + def import_tasks(self, tasks, preannotated_from_fields: Optional[List] = None): """Import JSON-formatted labeling tasks. Tasks can be unlabeled or contain predictions. Parameters @@ -2129,3 +2130,65 @@ def delete_all_tasks(self, excluded_ids: list = None) -> Response: return self.make_request( "POST", f"/api/dm/actions?project={self.id}&id=delete_tasks", json=payload ) + + def get_dataframe(self): + labeled_tasks = self.get_labeled_tasks() + if not labeled_tasks: + return pd.DataFrame() + records = [] + for labeled_task in labeled_tasks: + annotation = labeled_task['annotations'][0] + label = annotation['result'][0]['value']['choices'][0] + if 'predictions' in labeled_task: + prediction = labeled_task['predictions'][0] + prediction_label = prediction['result'][0]['value']['choices'][0] + else: + prediction_label = None + data = labeled_task['data'] + data.update({'ground_truth': label, 'predictions': prediction_label}) + records.append(data) + return pd.DataFrame.from_records(records) + + def label_dataframe( + self, + df: pd.DataFrame, + inplace=False, + polling_interval=10, + output_column='ground_truth', + preannotated_from_fields=None, + ): + + if not inplace: + df = df.copy() + + tasks = df.reset_index().to_dict(orient='records') + num_tasks = len(tasks) + unlabeled_indices = set(df.index) + + self.import_tasks(tasks, preannotated_from_fields=preannotated_from_fields) + + while True: + try: + labeled_tasks = self.get_labeled_tasks() + labeled_tasks = [task for task in labeled_tasks if task['data']['index'] in unlabeled_indices] + + if len(labeled_tasks) >= num_tasks: + logger.info(f'All {len(unlabeled_indices)} tasks have been annotated.') + break + + logger.info( + f'Waiting for the user to annotate the tasks. ' + f'{len(labeled_tasks)} out of {num_tasks} tasks have been annotated...') + time.sleep(polling_interval) + except Exception as e: + logger.error(f"Error while fetching labeled tasks: {e}") + time.sleep(polling_interval) + + for labeled_task in labeled_tasks: + index = labeled_task['data']['index'] + annotation = labeled_task['annotations'][0] + # TODO: support other types of tasks + label = annotation['result'][0]['value']['choices'][0] + df.loc[index, output_column] = label + + return df diff --git a/label_studio_sdk/utils.py b/label_studio_sdk/utils.py index 6a507e4c..f7dd4a94 100644 --- a/label_studio_sdk/utils.py +++ b/label_studio_sdk/utils.py @@ -1,9 +1,12 @@ """ .. include::../docs/utils.md """ import logging +import time +from datetime import datetime from lxml import etree from collections import defaultdict +from typing import Optional logger = logging.getLogger(__name__) @@ -130,3 +133,38 @@ def chunk(lst, n): """Yield successive n-sized chunks from lst.""" for i in range(0, len(lst), n): yield lst[i : i + n] + + +def get_or_create_project( + project_id: Optional[int] = None, + title: Optional[str] = None, + url: Optional[str] = None, + api_key: Optional[str] = None +): + """Get or create a Label Studio project. + + Parameters: + ----------- + project_id: int + ID of the Label Studio project to get or create if it doesn't exist + title: str + Title of the Label Studio project (if not provided, will be set to the current timestamp) + url: str + URL of the Label Studio instance (environment variable `LABEL_STUDIO_URL` if not provided) + api_key: str + API key for the Label Studio instance (environment variable `LABEL_STUDIO_API_KEY` if not provided) + """ + + from .client import Client + + client = Client(url=url, api_key=api_key) + if project_id is None: + title = title or f'SDK [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]' + existing_projects = client.get_projects(title=title) + if existing_projects: + project = existing_projects[0] + else: + project = client.create_project(title=title) + return project + project = client.get_project(id=project_id) + return project