Skip to content

Commit

Permalink
Add support for pd.DataFrame labeling
Browse files Browse the repository at this point in the history
  • Loading branch information
nik committed Aug 31, 2023
1 parent 900aae5 commit c3b2f70
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 2 deletions.
2 changes: 1 addition & 1 deletion label_studio_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def make_request(self, method, url, *args, **kwargs):
except:
content = response.text

logger.error(
logger.debug(
f'\n--------------------------------------------\n'
f'Request URL: {response.url}\n'
f'Response status code: {response.status_code}\n'
Expand Down
65 changes: 64 additions & 1 deletion label_studio_sdk/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import json
import logging
import pandas as pd
import pathlib
import time

Expand Down Expand Up @@ -474,7 +475,7 @@ def get_from_id(cls, client, project_id) -> "Project":
project.update_params()
return project

def import_tasks(self, tasks, preannotated_from_fields: List = None):
def import_tasks(self, tasks, preannotated_from_fields: Optional[List] = None):
"""Import JSON-formatted labeling tasks. Tasks can be unlabeled or contain predictions.
Parameters
Expand Down Expand Up @@ -2129,3 +2130,65 @@ def delete_all_tasks(self, excluded_ids: list = None) -> Response:
return self.make_request(
"POST", f"/api/dm/actions?project={self.id}&id=delete_tasks", json=payload
)

def get_dataframe(self):
labeled_tasks = self.get_labeled_tasks()
if not labeled_tasks:
return pd.DataFrame()
records = []
for labeled_task in labeled_tasks:
annotation = labeled_task['annotations'][0]
label = annotation['result'][0]['value']['choices'][0]
if 'predictions' in labeled_task:
prediction = labeled_task['predictions'][0]
prediction_label = prediction['result'][0]['value']['choices'][0]
else:
prediction_label = None
data = labeled_task['data']
data.update({'ground_truth': label, 'predictions': prediction_label})
records.append(data)
return pd.DataFrame.from_records(records)

def label_dataframe(
self,
df: pd.DataFrame,
inplace=False,
polling_interval=10,
output_column='ground_truth',
preannotated_from_fields=None,
):

if not inplace:
df = df.copy()

tasks = df.reset_index().to_dict(orient='records')
num_tasks = len(tasks)
unlabeled_indices = set(df.index)

self.import_tasks(tasks, preannotated_from_fields=preannotated_from_fields)

while True:
try:
labeled_tasks = self.get_labeled_tasks()
labeled_tasks = [task for task in labeled_tasks if task['data']['index'] in unlabeled_indices]

if len(labeled_tasks) >= num_tasks:
logger.info(f'All {len(unlabeled_indices)} tasks have been annotated.')
break

logger.info(
f'Waiting for the user to annotate the tasks. '
f'{len(labeled_tasks)} out of {num_tasks} tasks have been annotated...')
time.sleep(polling_interval)
except Exception as e:
logger.error(f"Error while fetching labeled tasks: {e}")
time.sleep(polling_interval)

for labeled_task in labeled_tasks:
index = labeled_task['data']['index']
annotation = labeled_task['annotations'][0]
# TODO: support other types of tasks
label = annotation['result'][0]['value']['choices'][0]
df.loc[index, output_column] = label

return df
38 changes: 38 additions & 0 deletions label_studio_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
""" .. include::../docs/utils.md
"""
import logging
import time

from datetime import datetime
from lxml import etree
from collections import defaultdict
from typing import Optional

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -130,3 +133,38 @@ def chunk(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i : i + n]


def get_or_create_project(
project_id: Optional[int] = None,
title: Optional[str] = None,
url: Optional[str] = None,
api_key: Optional[str] = None
):
"""Get or create a Label Studio project.
Parameters:
-----------
project_id: int
ID of the Label Studio project to get or create if it doesn't exist
title: str
Title of the Label Studio project (if not provided, will be set to the current timestamp)
url: str
URL of the Label Studio instance (environment variable `LABEL_STUDIO_URL` if not provided)
api_key: str
API key for the Label Studio instance (environment variable `LABEL_STUDIO_API_KEY` if not provided)
"""

from .client import Client

client = Client(url=url, api_key=api_key)
if project_id is None:
title = title or f'SDK [{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]'
existing_projects = client.get_projects(title=title)
if existing_projects:
project = existing_projects[0]
else:
project = client.create_project(title=title)
return project
project = client.get_project(id=project_id)
return project

0 comments on commit c3b2f70

Please sign in to comment.