diff --git a/.changeset/wicked-badgers-smash.md b/.changeset/wicked-badgers-smash.md
new file mode 100644
index 000000000000..e26658d2f2c5
--- /dev/null
+++ b/.changeset/wicked-badgers-smash.md
@@ -0,0 +1,6 @@
+---
+"@gradio/dataframe": minor
+"gradio": minor
+---
+
+feat:Added support for pandas `Styler` objects in `gr.DataFrame` (initially, only the Styler's `display_value` is applied)
diff --git a/gradio/components/dataframe.py b/gradio/components/dataframe.py
index b8981826a5d4..b893716dec83 100644
--- a/gradio/components/dataframe.py
+++ b/gradio/components/dataframe.py
@@ -3,12 +3,15 @@
from __future__ import annotations
import warnings
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from dataclasses import asdict, dataclass
+from typing import Callable, Literal
import numpy as np
import pandas as pd
+import semantic_version
from gradio_client.documentation import document, set_documentation_group
from gradio_client.serializing import JSONSerializable
+from pandas.io.formats.style import Styler
from gradio.components.base import IOComponent, _Keywords
from gradio.events import (
@@ -18,30 +21,41 @@
Selectable,
)
-if TYPE_CHECKING:
- from typing import TypedDict
+set_documentation_group("component")
- class DataframeData(TypedDict):
- headers: list[str]
- data: list[list[str | int | bool]]
+@dataclass
+class DataframeData:
+ """
+ This is a dataclass to represent all the data that is sent to or received from the frontend.
+ """
-set_documentation_group("component")
+ data: list[list[str | int | bool]]
+ headers: list[str] | list[int] | None = None
+ metadata: dict[str, list[list]] | None = None
@document()
class Dataframe(Changeable, Inputable, Selectable, IOComponent, JSONSerializable):
"""
Accepts or displays 2D input through a spreadsheet-like component for dataframes.
- Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type`
- Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
+ Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, or {List[List]} depending on `type`
+ Postprocessing: expects a {pandas.DataFrame}, {pandas.io.formats.style.Styler}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet.
Examples-format: a {str} filepath to a csv with data, a pandas dataframe, or a list of lists (excluding headers) where each sublist is a row of data.
Demos: filter_records, matrix_transpose, tax_calculator
"""
def __init__(
self,
- value: list[list[Any]] | Callable | None = None,
+ value: pd.DataFrame
+ | Styler
+ | np.ndarray
+ | list
+ | list[list]
+ | dict
+ | str
+ | Callable
+ | None = None,
*,
headers: list[str] | None = None,
row_count: int | tuple[int, str] = (1, "dynamic"),
@@ -67,12 +81,12 @@ def __init__(
):
"""
Parameters:
- value: Default value as a 2-dimensional list of values. If callable, the function will be called whenever the app loads to set the initial value of the component.
+ value: Default value to display in the DataFrame. If a Styler is provided, it will be used to set the displayed value in the DataFrame (e.g. to set precision of numbers) when `interactive` is False. If a Callable function is provided, the function will be called whenever the app loads to set the initial value of the component.
headers: List of str header names. If None, no headers are shown.
row_count: Limit number of rows for input and decide whether user can create new rows. The first element of the tuple is an `int`, the row count; the second should be 'fixed' or 'dynamic', the new row behaviour. If an `int` is passed the rows default to 'dynamic'
col_count: Limit number of columns for input and decide whether user can create new columns. The first element of the tuple is an `int`, the number of columns; the second should be 'fixed' or 'dynamic', the new column behaviour. If an `int` is passed the columns default to 'dynamic'
datatype: Datatype of values in sheet. Can be provided per column as a list of strings, or for the entire sheet as a single string. Valid datatypes are "str", "number", "bool", "date", and "markdown".
- type: Type of value to be returned by component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for a Python array.
+ type: Type of value to be returned by component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for a Python list of lists.
label: component name in interface.
max_rows: Deprecated and has no effect. Use `row_count` instead.
max_cols: Deprecated and has no effect. Use `col_count` instead.
@@ -157,7 +171,15 @@ def __init__(
@staticmethod
def update(
- value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
+ value: pd.DataFrame
+ | Styler
+ | np.ndarray
+ | list
+ | list[list]
+ | dict
+ | str
+ | Literal[_Keywords.NO_VALUE]
+ | None = _Keywords.NO_VALUE,
max_rows: int | None = None,
max_cols: str | None = None,
label: str | None = None,
@@ -187,22 +209,23 @@ def update(
"__type__": "update",
}
- def preprocess(self, x: DataframeData):
+ def preprocess(self, x: dict) -> pd.DataFrame | np.ndarray | list:
"""
Parameters:
- x: 2D array of str, numeric, or bool data
+ x: Dictionary equivalent of DataframeData containing a `data` key and, optionally, `headers` and `metadata` keys
Returns:
- Dataframe in requested format
+ The Dataframe data in requested format
"""
+ value = DataframeData(**x)
if self.type == "pandas":
- if x.get("headers") is not None:
- return pd.DataFrame(x["data"], columns=x.get("headers"))
+ if value.headers is not None:
+ return pd.DataFrame(value.data, columns=value.headers)
else:
- return pd.DataFrame(x["data"])
+ return pd.DataFrame(value.data)
if self.type == "numpy":
- return np.array(x["data"])
+ return np.array(value.data)
elif self.type == "array":
- return x["data"]
+ return value.data
else:
raise ValueError(
"Unknown type: "
@@ -211,7 +234,8 @@ def preprocess(self, x: DataframeData):
)
def postprocess(
- self, y: str | pd.DataFrame | np.ndarray | list[list[str | float]] | dict
+ self,
+ y: pd.DataFrame | Styler | np.ndarray | list | list[list] | dict | str | None,
) -> dict:
"""
Parameters:
@@ -222,15 +246,31 @@ def postprocess(
if y is None:
return self.postprocess(self.empty_input)
if isinstance(y, dict):
- return y
- if isinstance(y, (str, pd.DataFrame)):
- if isinstance(y, str):
- y = pd.read_csv(y)
- return {
- "headers": list(y.columns), # type: ignore
- "data": y.to_dict(orient="split")["data"], # type: ignore
- }
- if isinstance(y, (np.ndarray, list)):
+ value = DataframeData(**y)
+ elif isinstance(y, Styler):
+ if semantic_version.Version(pd.__version__) < semantic_version.Version(
+ "1.5.0"
+ ):
+ raise ValueError(
+ "Styler objects are only supported in pandas version 1.5.0 or higher. Please try: `pip install --upgrade pandas` to use this feature."
+ )
+ if self.interactive:
+ warnings.warn(
+ "Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead."
+ )
+ df: pd.DataFrame = y.data # type: ignore
+ value = DataframeData(
+ headers=list(df.columns),
+ data=df.to_dict(orient="split")["data"],
+ metadata=self.__extract_metadata(y),
+ )
+ elif isinstance(y, (str, pd.DataFrame)):
+ df = pd.read_csv(y) if isinstance(y, str) else y
+ value = DataframeData(
+ headers=list(df.columns),
+ data=df.to_dict(orient="split")["data"],
+ )
+ elif isinstance(y, (np.ndarray, list)):
if len(y) == 0:
return self.postprocess([[]])
if isinstance(y, np.ndarray):
@@ -238,7 +278,6 @@ def postprocess(
assert isinstance(y, list), "output cannot be converted to list"
_headers = self.headers
-
if len(self.headers) < len(y[0]):
_headers = [
*self.headers,
@@ -247,11 +286,27 @@ def postprocess(
elif len(self.headers) > len(y[0]):
_headers = self.headers[: len(y[0])]
- return {
- "headers": _headers,
- "data": y,
- }
- raise ValueError("Cannot process value as a Dataframe")
+ value = DataframeData(
+ headers=_headers,
+ data=y,
+ )
+ else:
+ raise ValueError(f"Cannot process value as a Dataframe: {y}")
+ return asdict(value)
+
+ @staticmethod
+ def __extract_metadata(df: Styler) -> dict[str, list[list]]:
+ metadata = {"display_value": []}
+ style_data = df._compute()._translate(None, None) # type: ignore
+ for i in range(len(style_data["body"])):
+ metadata["display_value"].append([])
+ for j in range(len(style_data["body"][i])):
+ cell_type = style_data["body"][i][j]["type"]
+ if cell_type != "td":
+ continue
+ display_value = style_data["body"][i][j]["display_value"]
+ metadata["display_value"][i].append(display_value)
+ return metadata
@staticmethod
def __process_counts(count, default=3) -> tuple[int, str]:
diff --git a/js/dataframe/Dataframe.stories.svelte b/js/dataframe/Dataframe.stories.svelte
index ebe7229f3ebc..210c0b09d472 100644
--- a/js/dataframe/Dataframe.stories.svelte
+++ b/js/dataframe/Dataframe.stories.svelte
@@ -51,6 +51,29 @@
}}
/>
+
+
`${i + 1}`)
+ .map((_, i) => `${i + 1}`),
+
+ metadata: null
};
}
@@ -87,7 +87,7 @@
{label}
{row_count}
{col_count}
- values={value}
+ {value}
{headers}
on:change={({ detail }) => {
value = detail;
diff --git a/js/dataframe/shared/EditableCell.svelte b/js/dataframe/shared/EditableCell.svelte
index 14958d72e86a..87e69de9a60f 100644
--- a/js/dataframe/shared/EditableCell.svelte
+++ b/js/dataframe/shared/EditableCell.svelte
@@ -5,6 +5,7 @@
export let edit: boolean;
export let value: string | number = "";
+ export let display_value: string | null = null;
export let header = false;
export let datatype:
| "str"
@@ -20,6 +21,7 @@
}[];
export let clear_on_focus = false;
export let select_on_focus = false;
+ export let editable = true;
const dispatch = createEventDispatcher();
@@ -71,7 +73,7 @@
chatbot={false}
/>
{:else}
- {value}
+ {editable ? value : display_value || value}
{/if}
diff --git a/js/dataframe/shared/Table.svelte b/js/dataframe/shared/Table.svelte
index 055d6a3b4164..19d5b6fafb11 100644
--- a/js/dataframe/shared/Table.svelte
+++ b/js/dataframe/shared/Table.svelte
@@ -9,15 +9,19 @@
import type { SelectData } from "@gradio/utils";
import { _ } from "svelte-i18n";
import VirtualTable from "./VirtualTable.svelte";
-
- type Datatype = "str" | "markdown" | "html" | "number" | "bool" | "date";
+ import type {
+ Headers,
+ HeadersWithIDs,
+ Data,
+ Metadata,
+ Datatype
+ } from "../shared/utils";
export let datatype: Datatype | Datatype[];
export let label: string | null = null;
- export let headers: string[] = [];
- export let values:
- | (string | number)[][]
- | { data: (string | number)[][]; headers: string[] } = [[]];
+ export let headers: Headers = [];
+ let values: (string | number)[][];
+ export let value: { data: Data; headers: Headers; metadata: Metadata } | null;
export let col_count: [number, "fixed" | "dynamic"];
export let row_count: [number, "fixed" | "dynamic"];
export let latex_delimiters: {
@@ -30,18 +34,24 @@
export let wrap = false;
export let height = 500;
let selected: false | [number, number] = false;
+ let display_value: string[][] | null = value?.metadata?.display_value ?? null;
$: {
- if (values && !Array.isArray(values)) {
- headers = values.headers;
- values = values.data;
+ if (value) {
+ headers = value.headers;
+ values = value.data;
+ display_value = value?.metadata?.display_value ?? null;
} else if (values === null) {
values = [];
}
}
const dispatch = createEventDispatcher<{
- change: { data: (string | number)[][]; headers: string[] };
+ change: {
+ data: (string | number)[][];
+ headers: string[];
+ metadata: Metadata;
+ };
select: SelectData;
}>();
@@ -64,12 +74,10 @@
let data_binding: Record = {};
- type Headers = { value: string; id: string }[];
-
function make_id(): string {
return Math.random().toString(36).substring(2, 15);
}
- function make_headers(_head: string[]): Headers {
+ function make_headers(_head: Headers): HeadersWithIDs {
let _h = _head || [];
if (col_count[1] === "fixed" && _h.length < col_count[0]) {
const fill = Array(col_count[0] - _h.length)
@@ -152,7 +160,8 @@
$: _headers &&
dispatch("change", {
data: data.map((r) => r.map(({ value }) => value)),
- headers: _headers.map((h) => h.value)
+ headers: _headers.map((h) => h.value),
+ metadata: editable ? null : { display_value: display_value }
});
function get_sort_status(
@@ -545,6 +554,7 @@
function sort_data(
_data: typeof data,
+ _display_value: string[][] | null,
col?: number,
dir?: SortDirection
): void {
@@ -556,12 +566,29 @@
if (typeof col !== "number" || !dir) {
return;
}
+ const indices = [...Array(_data.length).keys()];
+
if (dir === "asc") {
- _data.sort((a, b) => (a[col].value < b[col].value ? -1 : 1));
+ indices.sort((i, j) =>
+ _data[i][col].value < _data[j][col].value ? -1 : 1
+ );
} else if (dir === "des") {
- _data.sort((a, b) => (a[col].value > b[col].value ? -1 : 1));
+ indices.sort((i, j) =>
+ _data[i][col].value > _data[j][col].value ? -1 : 1
+ );
+ } else {
+ return;
}
+ // sort both data and display_value in place based on the values in data
+ const tempData = [..._data];
+ const tempData2 = _display_value ? [..._display_value] : null;
+ indices.forEach((originalIndex, sortedIndex) => {
+ _data[sortedIndex] = tempData[originalIndex];
+ if (_display_value && tempData2)
+ _display_value[sortedIndex] = tempData2[originalIndex];
+ });
+
data = data;
if (id) {
@@ -570,7 +597,7 @@
}
}
- $: sort_data(data, sort_by, sort_direction);
+ $: sort_data(data, display_value, sort_by, sort_direction);
$: selected_index = !!selected && selected[0];
@@ -751,7 +778,9 @@
((clear_on_focus = false), parent.focus())}
diff --git a/js/dataframe/shared/utils.ts b/js/dataframe/shared/utils.ts
new file mode 100644
index 000000000000..09d4e0940d01
--- /dev/null
+++ b/js/dataframe/shared/utils.ts
@@ -0,0 +1,7 @@
+export type Headers = string[];
+export type Data = (string | number)[][];
+export type Datatype = "str" | "markdown" | "html" | "number" | "bool" | "date";
+export type Metadata = {
+ [key: string]: string[][] | null;
+} | null;
+export type HeadersWithIDs = { value: string; id: string }[];
diff --git a/js/dataframe/static/StaticDataframe.svelte b/js/dataframe/static/StaticDataframe.svelte
index 921bf394717d..dc2058ec24f2 100644
--- a/js/dataframe/static/StaticDataframe.svelte
+++ b/js/dataframe/static/StaticDataframe.svelte
@@ -5,18 +5,15 @@
import Table from "../shared";
import { StatusTracker } from "@gradio/statustracker";
import type { LoadingStatus } from "@gradio/statustracker";
-
- type Headers = string[];
- type Data = (string | number)[][];
- type Datatype = "str" | "markdown" | "html" | "number" | "bool" | "date";
-
+ import type { Headers, Data, Metadata, Datatype } from "../shared/utils";
export let headers: Headers = [];
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
- export let value: { data: Data; headers: Headers } = {
+ export let value: { data: Data; headers: Headers; metadata: Metadata } = {
data: [["", "", ""]],
- headers: ["1", "2", "3"]
+ headers: ["1", "2", "3"],
+ metadata: null
};
let old_value: string = JSON.stringify(value);
export let value_is_output = false;
@@ -56,7 +53,6 @@
handle_change();
}
}
-
if (
(Array.isArray(value) && value?.[0]?.length === 0) ||
value.data?.[0]?.length === 0
@@ -65,7 +61,8 @@
data: [Array(col_count?.[0] || 3).fill("")],
headers: Array(col_count?.[0] || 3)
.fill("")
- .map((_, i) => `${i + 1}`)
+ .map((_, i) => `${i + 1}`),
+ metadata: null
};
}
@@ -85,11 +82,8 @@
{label}
{row_count}
{col_count}
- values={value}
+ {value}
{headers}
- on:change={({ detail }) => {
- value = detail;
- }}
on:select={(e) => gradio.dispatch("select", e.detail)}
{wrap}
{datatype}
diff --git a/test/requirements.txt b/test/requirements.txt
index 6a0d5379f985..7f1bc6c4e57f 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -118,7 +118,7 @@ packaging==22.0
# scikit-image
# shap
# transformers
-pandas==1.3.5
+pandas==1.5.3
# via
# altair
# shap
diff --git a/test/test_components.py b/test/test_components.py
index de5bb64743d7..ca98f975f6eb 100644
--- a/test/test_components.py
+++ b/test/test_components.py
@@ -1207,6 +1207,7 @@ def test_component_functions(self):
x_data = {
"data": [["Tim", 12, False], ["Jan", 24, True]],
"headers": ["Name", "Age", "Member"],
+ "metadata": None,
}
dataframe_input = gr.Dataframe(headers=["Name", "Age", "Member"])
output = dataframe_input.preprocess(x_data)
@@ -1218,7 +1219,11 @@ def test_component_functions(self):
headers=["Name", "Age", "Member"], label="Dataframe Input"
)
assert dataframe_input.get_config() == {
- "value": {"headers": ["Name", "Age", "Member"], "data": [["", "", ""]]},
+ "value": {
+ "headers": ["Name", "Age", "Member"],
+ "data": [["", "", ""]],
+ "metadata": None,
+ },
"selectable": False,
"headers": ["Name", "Age", "Member"],
"row_count": (1, "dynamic"),
@@ -1250,7 +1255,7 @@ def test_component_functions(self):
dataframe_output = gr.Dataframe()
assert dataframe_output.get_config() == {
- "value": {"headers": [1, 2, 3], "data": [["", "", ""]]},
+ "value": {"headers": [1, 2, 3], "data": [["", "", ""]], "metadata": None},
"selectable": False,
"headers": [1, 2, 3],
"row_count": (1, "dynamic"),
@@ -1281,17 +1286,18 @@ def test_postprocess(self):
"""
dataframe_output = gr.Dataframe()
output = dataframe_output.postprocess([])
- assert output == {"data": [[]], "headers": []}
+ assert output == {"data": [[]], "headers": [], "metadata": None}
output = dataframe_output.postprocess(np.zeros((2, 2)))
- assert output == {"data": [[0, 0], [0, 0]], "headers": [1, 2]}
+ assert output == {"data": [[0, 0], [0, 0]], "headers": [1, 2], "metadata": None}
output = dataframe_output.postprocess([[1, 3, 5]])
- assert output == {"data": [[1, 3, 5]], "headers": [1, 2, 3]}
+ assert output == {"data": [[1, 3, 5]], "headers": [1, 2, 3], "metadata": None}
output = dataframe_output.postprocess(
pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"])
)
assert output == {
"headers": ["num", "prime"],
"data": [[2, True], [3, True], [4, False]],
+ "metadata": None,
}
with pytest.raises(ValueError):
gr.Dataframe(type="unknown")
@@ -1302,12 +1308,14 @@ def test_postprocess(self):
assert output == {
"headers": ["one", "two"],
"data": [[2, True], [3, True]],
+ "metadata": None,
}
dataframe_output = gr.Dataframe(headers=["one", "two", "three"])
output = dataframe_output.postprocess([[2, True, "ab", 4], [3, True, "cd", 5]])
assert output == {
"headers": ["one", "two", "three", 4],
"data": [[2, True, "ab", 4], [3, True, "cd", 5]],
+ "metadata": None,
}
def test_dataframe_postprocess_all_types(self):
@@ -1347,6 +1355,7 @@ def test_dataframe_postprocess_all_types(self):
"# Goodbye",
],
],
+ "metadata": None,
}
def test_dataframe_postprocess_only_dates(self):
@@ -1370,6 +1379,44 @@ def test_dataframe_postprocess_only_dates(self):
pd.Timestamp("2022-02-16 00:00:00"),
],
],
+ "metadata": None,
+ }
+
+ def test_dataframe_postprocess_styler(self):
+ component = gr.Dataframe()
+ df = pd.DataFrame(
+ {
+ "name": ["Adam", "Mike"] * 4,
+ "gpa": [1.1, 1.12] * 4,
+ "sat": [800, 800] * 4,
+ }
+ )
+ s = df.style.format(precision=1, decimal=",")
+ output = component.postprocess(s)
+ assert output == {
+ "data": [
+ ["Adam", 1.1, 800],
+ ["Mike", 1.12, 800],
+ ["Adam", 1.1, 800],
+ ["Mike", 1.12, 800],
+ ["Adam", 1.1, 800],
+ ["Mike", 1.12, 800],
+ ["Adam", 1.1, 800],
+ ["Mike", 1.12, 800],
+ ],
+ "headers": ["name", "gpa", "sat"],
+ "metadata": {
+ "display_value": [
+ ["Adam", "1,1", "800"],
+ ["Mike", "1,1", "800"],
+ ["Adam", "1,1", "800"],
+ ["Mike", "1,1", "800"],
+ ["Adam", "1,1", "800"],
+ ["Mike", "1,1", "800"],
+ ["Adam", "1,1", "800"],
+ ["Mike", "1,1", "800"],
+ ]
+ },
}