Skip to content

Commit

Permalink
Added support for pandas Styler object to gr.DataFrame (initially…
Browse files Browse the repository at this point in the history
… just sets the `display_value`) (#5569)

* adding precision to df

* add changeset

* docstring

* precision

* add changeset

* fix

* fixes

* add changeset

* add visual test

* lint

* fixes

* lint

* format

* add changeset

* ts changes

* analytics

* dataframe typing

* typing

* demo

* fix

* lint

* interactive dataframe

* dataframe

* fix typing

* add test

* upgrade pandas version

* fix pandas version

---------

Co-authored-by: gradio-pr-bot <[email protected]>
  • Loading branch information
abidlabs and gradio-pr-bot authored Oct 2, 2023
1 parent caf6d9c commit 2a5b9e0
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 82 deletions.
6 changes: 6 additions & 0 deletions .changeset/wicked-badgers-smash.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@gradio/dataframe": minor
"gradio": minor
---

feat:Added support for pandas `Styler` object to `gr.DataFrame` (initially just sets the `display_value`)
129 changes: 92 additions & 37 deletions gradio/components/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal
from dataclasses import asdict, dataclass
from typing import Callable, Literal

import numpy as np
import pandas as pd
import semantic_version
from gradio_client.documentation import document, set_documentation_group
from gradio_client.serializing import JSONSerializable
from pandas.io.formats.style import Styler

from gradio.components.base import IOComponent, _Keywords
from gradio.events import (
Expand All @@ -18,30 +21,41 @@
Selectable,
)

if TYPE_CHECKING:
from typing import TypedDict
set_documentation_group("component")

class DataframeData(TypedDict):
headers: list[str]
data: list[list[str | int | bool]]

@dataclass
class DataframeData:
"""
This is a dataclass to represent all the data that is sent to or received from the frontend.
"""

set_documentation_group("component")
data: list[list[str | int | bool]]
headers: list[str] | list[int] | None = None
metadata: dict[str, list[list]] | None = None


@document()
class Dataframe(Changeable, Inputable, Selectable, IOComponent, JSONSerializable):
"""
Accepts or displays 2D input through a spreadsheet-like component for dataframes.
Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type`
Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, or {List[List]} depending on `type`
Postprocessing: expects a {pandas.DataFrame}, {pandas.Styler}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
Examples-format: a {str} filepath to a csv with data, a pandas dataframe, or a list of lists (excluding headers) where each sublist is a row of data.
Demos: filter_records, matrix_transpose, tax_calculator
"""

def __init__(
self,
value: list[list[Any]] | Callable | None = None,
value: pd.DataFrame
| Styler
| np.ndarray
| list
| list[list]
| dict
| str
| Callable
| None = None,
*,
headers: list[str] | None = None,
row_count: int | tuple[int, str] = (1, "dynamic"),
Expand All @@ -67,12 +81,12 @@ def __init__(
):
"""
Parameters:
value: Default value as a 2-dimensional list of values. If callable, the function will be called whenever the app loads to set the initial value of the component.
value: Default value to display in the DataFrame. If a Styler is provided, it is used to set the displayed value of each cell (e.g. the precision of numbers), but only when `interactive` is False. If a Callable is provided, it will be called whenever the app loads to set the initial value of the component.
headers: List of str header names. If None, no headers are shown.
row_count: Limit number of rows for input and decide whether user can create new rows. The first element of the tuple is an `int`, the row count; the second should be 'fixed' or 'dynamic', the new row behaviour. If an `int` is passed the rows default to 'dynamic'
col_count: Limit number of columns for input and decide whether user can create new columns. The first element of the tuple is an `int`, the number of columns; the second should be 'fixed' or 'dynamic', the new column behaviour. If an `int` is passed the columns default to 'dynamic'
datatype: Datatype of values in sheet. Can be provided per column as a list of strings, or for the entire sheet as a single string. Valid datatypes are "str", "number", "bool", "date", and "markdown".
type: Type of value to be returned by component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for a Python array.
type: Type of value to be returned by component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for a Python list of lists.
label: component name in interface.
max_rows: Deprecated and has no effect. Use `row_count` instead.
max_cols: Deprecated and has no effect. Use `col_count` instead.
Expand Down Expand Up @@ -157,7 +171,15 @@ def __init__(

@staticmethod
def update(
value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
value: pd.DataFrame
| Styler
| np.ndarray
| list
| list[list]
| dict
| str
| Literal[_Keywords.NO_VALUE]
| None = _Keywords.NO_VALUE,
max_rows: int | None = None,
max_cols: str | None = None,
label: str | None = None,
Expand Down Expand Up @@ -187,22 +209,23 @@ def update(
"__type__": "update",
}

def preprocess(self, x: DataframeData):
def preprocess(self, x: dict) -> pd.DataFrame | np.ndarray | list:
"""
Parameters:
x: 2D array of str, numeric, or bool data
x: Dictionary equivalent of DataframeData containing `headers`, `data`, and optionally `metadata` keys
Returns:
Dataframe in requested format
The Dataframe data in requested format
"""
value = DataframeData(**x)
if self.type == "pandas":
if x.get("headers") is not None:
return pd.DataFrame(x["data"], columns=x.get("headers"))
if value.headers is not None:
return pd.DataFrame(value.data, columns=value.headers)
else:
return pd.DataFrame(x["data"])
return pd.DataFrame(value.data)
if self.type == "numpy":
return np.array(x["data"])
return np.array(value.data)
elif self.type == "array":
return x["data"]
return value.data
else:
raise ValueError(
"Unknown type: "
Expand All @@ -211,7 +234,8 @@ def preprocess(self, x: DataframeData):
)

def postprocess(
self, y: str | pd.DataFrame | np.ndarray | list[list[str | float]] | dict
self,
y: pd.DataFrame | Styler | np.ndarray | list | list[list] | dict | str | None,
) -> dict:
"""
Parameters:
Expand All @@ -222,23 +246,38 @@ def postprocess(
if y is None:
return self.postprocess(self.empty_input)
if isinstance(y, dict):
return y
if isinstance(y, (str, pd.DataFrame)):
if isinstance(y, str):
y = pd.read_csv(y)
return {
"headers": list(y.columns), # type: ignore
"data": y.to_dict(orient="split")["data"], # type: ignore
}
if isinstance(y, (np.ndarray, list)):
value = DataframeData(**y)
elif isinstance(y, Styler):
if semantic_version.Version(pd.__version__) < semantic_version.Version(
"1.5.0"
):
raise ValueError(
"Styler objects are only supported in pandas version 1.5.0 or higher. Please try: `pip install --upgrade pandas` to use this feature."
)
if self.interactive:
warnings.warn(
"Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead."
)
df: pd.DataFrame = y.data # type: ignore
value = DataframeData(
headers=list(df.columns),
data=df.to_dict(orient="split")["data"],
metadata=self.__extract_metadata(y),
)
elif isinstance(y, (str, pd.DataFrame)):
df = pd.read_csv(y) if isinstance(y, str) else y
value = DataframeData(
headers=list(df.columns),
data=df.to_dict(orient="split")["data"],
)
elif isinstance(y, (np.ndarray, list)):
if len(y) == 0:
return self.postprocess([[]])
if isinstance(y, np.ndarray):
y = y.tolist()
assert isinstance(y, list), "output cannot be converted to list"

_headers = self.headers

if len(self.headers) < len(y[0]):
_headers = [
*self.headers,
Expand All @@ -247,11 +286,27 @@ def postprocess(
elif len(self.headers) > len(y[0]):
_headers = self.headers[: len(y[0])]

return {
"headers": _headers,
"data": y,
}
raise ValueError("Cannot process value as a Dataframe")
value = DataframeData(
headers=_headers,
data=y,
)
else:
raise ValueError(f"Cannot process value as a Dataframe: {y}")
return asdict(value)

@staticmethod
def __extract_metadata(df: Styler) -> dict[str, list[list]]:
    """
    Collects the formatted display value of every data cell of a Styler.
    Parameters:
        df: The pandas Styler whose rendered cell values should be extracted
    Returns:
        A dictionary with a single `display_value` key, holding one list per body row that contains the display values of that row's `td` cells (cells of any other type are skipped)
    """
    # Uses the Styler's private `_compute` / `_translate` helpers to obtain the
    # rendered table structure; only its "body" rows are read here.
    rendered = df._compute()._translate(None, None)  # type: ignore
    return {
        "display_value": [
            [cell["display_value"] for cell in row if cell["type"] == "td"]
            for row in rendered["body"]
        ]
    }

@staticmethod
def __process_counts(count, default=3) -> tuple[int, str]:
Expand Down
23 changes: 23 additions & 0 deletions js/dataframe/Dataframe.stories.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,29 @@
}}
/>

<Story
name="Dataframe with different precisions"
args={{
values: {
data: [
[1.24, 1.24, 1.24],
[1.21, 1.21, 1.21]
],
metadata: {
display_value: [
["1", "1.2", "1.24"],
["1", "1.2", "1.21"]
]
}
},
headers: ["Precision=0", "Precision=1", "Precision=2"],
label: "Animals",
col_count: [3, "dynamic"],
row_count: [2, "dynamic"],
editable: false
}}
/>

<Story
name="Dataframe with markdown and math"
args={{
Expand Down
16 changes: 8 additions & 8 deletions js/dataframe/interactive/InteractiveDataframe.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@
import type { LoadingStatus } from "@gradio/statustracker";
import { afterUpdate } from "svelte";
import { _ } from "svelte-i18n";
type Headers = string[];
type Data = (string | number)[][];
type Datatype = "str" | "markdown" | "html" | "number" | "bool" | "date";
import type { Headers, Data, Metadata, Datatype } from "../shared/utils";
export let headers: Headers = [];
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
export let value: { data: Data; headers: Headers } = {
export let value: { data: Data; headers: Headers; metadata: Metadata } = {
data: [["", "", ""]],
headers: ["1", "2", "3"]
headers: ["1", "2", "3"],
metadata: null
};
export let latex_delimiters: {
left: string;
Expand Down Expand Up @@ -67,7 +65,9 @@
data: [Array(col_count?.[0] || 3).fill("")],
headers: Array(col_count?.[0] || 3)
.fill("")
.map((_, i) => `${i + 1}`)
.map((_, i) => `${i + 1}`),
metadata: null
};
}
</script>
Expand All @@ -87,7 +87,7 @@
{label}
{row_count}
{col_count}
values={value}
{value}
{headers}
on:change={({ detail }) => {
value = detail;
Expand Down
4 changes: 3 additions & 1 deletion js/dataframe/shared/EditableCell.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
export let edit: boolean;
export let value: string | number = "";
export let display_value: string | null = null;
export let header = false;
export let datatype:
| "str"
Expand All @@ -20,6 +21,7 @@
}[];
export let clear_on_focus = false;
export let select_on_focus = false;
export let editable = true;
const dispatch = createEventDispatcher();
Expand Down Expand Up @@ -71,7 +73,7 @@
chatbot={false}
/>
{:else}
{value}
{editable ? value : display_value || value}
{/if}
</span>

Expand Down
Loading

0 comments on commit 2a5b9e0

Please sign in to comment.