diff --git a/.changeset/wicked-badgers-smash.md b/.changeset/wicked-badgers-smash.md new file mode 100644 index 000000000000..e26658d2f2c5 --- /dev/null +++ b/.changeset/wicked-badgers-smash.md @@ -0,0 +1,6 @@ +--- +"@gradio/dataframe": minor +"gradio": minor +--- + +feat:Added support for pandas `Styler` object to `gr.DataFrame` (initially just sets the `display_value`) diff --git a/gradio/components/dataframe.py b/gradio/components/dataframe.py index b8981826a5d4..b893716dec83 100644 --- a/gradio/components/dataframe.py +++ b/gradio/components/dataframe.py @@ -3,12 +3,15 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Callable, Literal +from dataclasses import asdict, dataclass +from typing import Callable, Literal import numpy as np import pandas as pd +import semantic_version from gradio_client.documentation import document, set_documentation_group from gradio_client.serializing import JSONSerializable +from pandas.io.formats.style import Styler from gradio.components.base import IOComponent, _Keywords from gradio.events import ( @@ -18,30 +21,41 @@ Selectable, ) -if TYPE_CHECKING: - from typing import TypedDict +set_documentation_group("component") - class DataframeData(TypedDict): - headers: list[str] - data: list[list[str | int | bool]] +@dataclass +class DataframeData: + """ + This is a dataclass to represent all the data that is sent to or received from the frontend. + """ -set_documentation_group("component") + data: list[list[str | int | bool]] + headers: list[str] | list[int] | None = None + metadata: dict[str, list[list]] | None = None @document() class Dataframe(Changeable, Inputable, Selectable, IOComponent, JSONSerializable): """ Accepts or displays 2D input through a spreadsheet-like component for dataframes. - Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, {List[List]}, or {List} depending on `type` - Postprocessing: expects a {pandas.DataFrame}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet. + Preprocessing: passes the uploaded spreadsheet data as a {pandas.DataFrame}, {numpy.array}, or {List[List]} depending on `type` + Postprocessing: expects a {pandas.DataFrame}, {pandas.Styler}, {numpy.array}, {List[List]}, {List}, a {Dict} with keys `data` (and optionally `headers`), or {str} path to a csv, which is rendered in the spreadsheet. Examples-format: a {str} filepath to a csv with data, a pandas dataframe, or a list of lists (excluding headers) where each sublist is a row of data. Demos: filter_records, matrix_transpose, tax_calculator """ def __init__( self, - value: list[list[Any]] | Callable | None = None, + value: pd.DataFrame + | Styler + | np.ndarray + | list + | list[list] + | dict + | str + | Callable + | None = None, *, headers: list[str] | None = None, row_count: int | tuple[int, str] = (1, "dynamic"), @@ -67,12 +81,12 @@ def __init__( ): """ Parameters: - value: Default value as a 2-dimensional list of values. If callable, the function will be called whenever the app loads to set the initial value of the component. + value: Default value to display in the DataFrame. If a Styler is provided, it will be used to set the displayed value in the DataFrame (e.g. to set precision of numbers) if the `interactive` is False. If a Callable function is provided, the function will be called whenever the app loads to set the initial value of the component. headers: List of str header names. If None, no headers are shown. row_count: Limit number of rows for input and decide whether user can create new rows. The first element of the tuple is an `int`, the row count; the second should be 'fixed' or 'dynamic', the new row behaviour. If an `int` is passed the rows default to 'dynamic' col_count: Limit number of columns for input and decide whether user can create new columns. The first element of the tuple is an `int`, the number of columns; the second should be 'fixed' or 'dynamic', the new column behaviour. If an `int` is passed the columns default to 'dynamic' datatype: Datatype of values in sheet. Can be provided per column as a list of strings, or for the entire sheet as a single string. Valid datatypes are "str", "number", "bool", "date", and "markdown". - type: Type of value to be returned by component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for a Python array. + type: Type of value to be returned by component. "pandas" for pandas dataframe, "numpy" for numpy array, or "array" for a Python list of lists. label: component name in interface. max_rows: Deprecated and has no effect. Use `row_count` instead. max_cols: Deprecated and has no effect. Use `col_count` instead. @@ -157,7 +171,15 @@ def __init__( @staticmethod def update( - value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, + value: pd.DataFrame + | Styler + | np.ndarray + | list + | list[list] + | dict + | str + | Literal[_Keywords.NO_VALUE] + | None = _Keywords.NO_VALUE, max_rows: int | None = None, max_cols: str | None = None, label: str | None = None, @@ -187,22 +209,23 @@ def update( "__type__": "update", } - def preprocess(self, x: DataframeData): + def preprocess(self, x: dict) -> pd.DataFrame | np.ndarray | list: """ Parameters: - x: 2D array of str, numeric, or bool data + x: Dictionary equivalent of DataframeData containing `headers`, `data`, and optionally `metadata` keys Returns: - Dataframe in requested format + The Dataframe data in requested format """ + value = DataframeData(**x) if self.type == "pandas": - if x.get("headers") is not None: - return pd.DataFrame(x["data"], columns=x.get("headers")) + if value.headers is not None: + return pd.DataFrame(value.data, columns=value.headers) else: - return pd.DataFrame(x["data"]) + return pd.DataFrame(value.data) if self.type == "numpy": - return np.array(x["data"]) + return np.array(value.data) elif self.type == "array": - return x["data"] + return value.data else: raise ValueError( "Unknown type: " @@ -211,7 +234,8 @@ def preprocess(self, x: DataframeData): ) def postprocess( - self, y: str | pd.DataFrame | np.ndarray | list[list[str | float]] | dict + self, + y: pd.DataFrame | Styler | np.ndarray | list | list[list] | dict | str | None, ) -> dict: """ Parameters: @@ -222,15 +246,31 @@ def postprocess( if y is None: return self.postprocess(self.empty_input) if isinstance(y, dict): - return y - if isinstance(y, (str, pd.DataFrame)): - if isinstance(y, str): - y = pd.read_csv(y) - return { - "headers": list(y.columns), # type: ignore - "data": y.to_dict(orient="split")["data"], # type: ignore - } - if isinstance(y, (np.ndarray, list)): + value = DataframeData(**y) + elif isinstance(y, Styler): + if semantic_version.Version(pd.__version__) < semantic_version.Version( + "1.5.0" + ): + raise ValueError( + "Styler objects are only supported in pandas version 1.5.0 or higher. Please try: `pip install --upgrade pandas` to use this feature." + ) + if self.interactive: + warnings.warn( + "Cannot display Styler object in interactive mode. Will display as a regular pandas dataframe instead." + ) + df: pd.DataFrame = y.data # type: ignore + value = DataframeData( + headers=list(df.columns), + data=df.to_dict(orient="split")["data"], + metadata=self.__extract_metadata(y), + ) + elif isinstance(y, (str, pd.DataFrame)): + df = pd.read_csv(y) if isinstance(y, str) else y + value = DataframeData( + headers=list(df.columns), + data=df.to_dict(orient="split")["data"], + ) + elif isinstance(y, (np.ndarray, list)): if len(y) == 0: return self.postprocess([[]]) if isinstance(y, np.ndarray): @@ -238,7 +278,6 @@ def postprocess( assert isinstance(y, list), "output cannot be converted to list" _headers = self.headers - if len(self.headers) < len(y[0]): _headers = [ *self.headers, @@ -247,11 +286,27 @@ def postprocess( elif len(self.headers) > len(y[0]): _headers = self.headers[: len(y[0])] - return { - "headers": _headers, - "data": y, - } - raise ValueError("Cannot process value as a Dataframe") + value = DataframeData( + headers=_headers, + data=y, + ) + else: + raise ValueError(f"Cannot process value as a Dataframe: {y}") + return asdict(value) + + @staticmethod + def __extract_metadata(df: Styler) -> dict[str, list[list]]: + metadata = {"display_value": []} + style_data = df._compute()._translate(None, None) # type: ignore + for i in range(len(style_data["body"])): + metadata["display_value"].append([]) + for j in range(len(style_data["body"][i])): + cell_type = style_data["body"][i][j]["type"] + if cell_type != "td": + continue + display_value = style_data["body"][i][j]["display_value"] + metadata["display_value"][i].append(display_value) + return metadata @staticmethod def __process_counts(count, default=3) -> tuple[int, str]: diff --git a/js/dataframe/Dataframe.stories.svelte b/js/dataframe/Dataframe.stories.svelte index ebe7229f3ebc..210c0b09d472 100644 --- a/js/dataframe/Dataframe.stories.svelte +++ b/js/dataframe/Dataframe.stories.svelte @@ -51,6 +51,29 @@ }} /> + + `${i + 1}`) + .map((_, i) => `${i + 1}`), + + metadata: null }; } @@ -87,7 +87,7 @@ {label} {row_count} {col_count} - values={value} + {value} {headers} on:change={({ detail }) => { value = detail; diff --git a/js/dataframe/shared/EditableCell.svelte b/js/dataframe/shared/EditableCell.svelte index 14958d72e86a..87e69de9a60f 100644 --- a/js/dataframe/shared/EditableCell.svelte +++ b/js/dataframe/shared/EditableCell.svelte @@ -5,6 +5,7 @@ export let edit: boolean; export let value: string | number = ""; + export let display_value: string | null = null; export let header = false; export let datatype: | "str" @@ -20,6 +21,7 @@ }[]; export let clear_on_focus = false; export let select_on_focus = false; + export let editable = true; const dispatch = createEventDispatcher(); @@ -71,7 +73,7 @@ chatbot={false} /> {:else} - {value} + {editable ? value : display_value || value} {/if} diff --git a/js/dataframe/shared/Table.svelte b/js/dataframe/shared/Table.svelte index 055d6a3b4164..19d5b6fafb11 100644 --- a/js/dataframe/shared/Table.svelte +++ b/js/dataframe/shared/Table.svelte @@ -9,15 +9,19 @@ import type { SelectData } from "@gradio/utils"; import { _ } from "svelte-i18n"; import VirtualTable from "./VirtualTable.svelte"; - - type Datatype = "str" | "markdown" | "html" | "number" | "bool" | "date"; + import type { + Headers, + HeadersWithIDs, + Data, + Metadata, + Datatype + } from "../shared/utils"; export let datatype: Datatype | Datatype[]; export let label: string | null = null; - export let headers: string[] = []; - export let values: - | (string | number)[][] - | { data: (string | number)[][]; headers: string[] } = [[]]; + export let headers: Headers = []; + let values: (string | number)[][]; + export let value: { data: Data; headers: Headers; metadata: Metadata } | null; export let col_count: [number, "fixed" | "dynamic"]; export let row_count: [number, "fixed" | "dynamic"]; export let latex_delimiters: { @@ -30,18 +34,24 @@ export let wrap = false; export let height = 500; let selected: false | [number, number] = false; + let display_value: string[][] | null = value?.metadata?.display_value ?? null; $: { - if (values && !Array.isArray(values)) { - headers = values.headers; - values = values.data; + if (value) { + headers = value.headers; + values = value.data; + display_value = value?.metadata?.display_value ?? null; } else if (values === null) { values = []; } } const dispatch = createEventDispatcher<{ - change: { data: (string | number)[][]; headers: string[] }; + change: { + data: (string | number)[][]; + headers: string[]; + metadata: Metadata; + }; select: SelectData; }>(); @@ -64,12 +74,10 @@ let data_binding: Record = {}; - type Headers = { value: string; id: string }[]; - function make_id(): string { return Math.random().toString(36).substring(2, 15); } - function make_headers(_head: string[]): Headers { + function make_headers(_head: Headers): HeadersWithIDs { let _h = _head || []; if (col_count[1] === "fixed" && _h.length < col_count[0]) { const fill = Array(col_count[0] - _h.length) @@ -152,7 +160,8 @@ $: _headers && dispatch("change", { data: data.map((r) => r.map(({ value }) => value)), - headers: _headers.map((h) => h.value) + headers: _headers.map((h) => h.value), + metadata: editable ? null : { display_value: display_value } }); function get_sort_status( @@ -545,6 +554,7 @@ function sort_data( _data: typeof data, + _display_value: string[][] | null, col?: number, dir?: SortDirection ): void { @@ -556,12 +566,29 @@ if (typeof col !== "number" || !dir) { return; } + const indices = [...Array(_data.length).keys()]; + if (dir === "asc") { - _data.sort((a, b) => (a[col].value < b[col].value ? -1 : 1)); + indices.sort((i, j) => + _data[i][col].value < _data[j][col].value ? -1 : 1 + ); } else if (dir === "des") { - _data.sort((a, b) => (a[col].value > b[col].value ? -1 : 1)); + indices.sort((i, j) => + _data[i][col].value > _data[j][col].value ? -1 : 1 + ); + } else { + return; } + // sort both data and display_value in place based on the values in data + const tempData = [..._data]; + const tempData2 = _display_value ? [..._display_value] : null; + indices.forEach((originalIndex, sortedIndex) => { + _data[sortedIndex] = tempData[originalIndex]; + if (_display_value && tempData2) + _display_value[sortedIndex] = tempData2[originalIndex]; + }); + data = data; if (id) { @@ -570,7 +597,7 @@ } } - $: sort_data(data, sort_by, sort_direction); + $: sort_data(data, display_value, sort_by, sort_direction); $: selected_index = !!selected && selected[0]; @@ -751,7 +778,9 @@ ((clear_on_focus = false), parent.focus())} diff --git a/js/dataframe/shared/utils.ts b/js/dataframe/shared/utils.ts new file mode 100644 index 000000000000..09d4e0940d01 --- /dev/null +++ b/js/dataframe/shared/utils.ts @@ -0,0 +1,7 @@ +export type Headers = string[]; +export type Data = (string | number)[][]; +export type Datatype = "str" | "markdown" | "html" | "number" | "bool" | "date"; +export type Metadata = { + [key: string]: string[][] | null; +} | null; +export type HeadersWithIDs = { value: string; id: string }[]; diff --git a/js/dataframe/static/StaticDataframe.svelte b/js/dataframe/static/StaticDataframe.svelte index 921bf394717d..dc2058ec24f2 100644 --- a/js/dataframe/static/StaticDataframe.svelte +++ b/js/dataframe/static/StaticDataframe.svelte @@ -5,18 +5,15 @@ import Table from "../shared"; import { StatusTracker } from "@gradio/statustracker"; import type { LoadingStatus } from "@gradio/statustracker"; - - type Headers = string[]; - type Data = (string | number)[][]; - type Datatype = "str" | "markdown" | "html" | "number" | "bool" | "date"; - + import type { Headers, Data, Metadata, Datatype } from "../shared/utils"; export let headers: Headers = []; export let elem_id = ""; export let elem_classes: string[] = []; export let visible = true; - export let value: { data: Data; headers: Headers } = { + export let value: { data: Data; headers: Headers; metadata: Metadata } = { data: [["", "", ""]], - headers: ["1", "2", "3"] + headers: ["1", "2", "3"], + metadata: null }; let old_value: string = JSON.stringify(value); export let value_is_output = false; @@ -56,7 +53,6 @@ handle_change(); } } - if ( (Array.isArray(value) && value?.[0]?.length === 0) || value.data?.[0]?.length === 0 @@ -65,7 +61,8 @@ data: [Array(col_count?.[0] || 3).fill("")], headers: Array(col_count?.[0] || 3) .fill("") - .map((_, i) => `${i + 1}`) + .map((_, i) => `${i + 1}`), + metadata: null }; } @@ -85,11 +82,8 @@ {label} {row_count} {col_count} - values={value} + {value} {headers} - on:change={({ detail }) => { - value = detail; - }} on:select={(e) => gradio.dispatch("select", e.detail)} {wrap} {datatype} diff --git a/test/requirements.txt b/test/requirements.txt index 6a0d5379f985..7f1bc6c4e57f 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -118,7 +118,7 @@ packaging==22.0 # scikit-image # shap # transformers -pandas==1.3.5 +pandas==1.5.3 # via # altair # shap diff --git a/test/test_components.py b/test/test_components.py index de5bb64743d7..ca98f975f6eb 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -1207,6 +1207,7 @@ def test_component_functions(self): x_data = { "data": [["Tim", 12, False], ["Jan", 24, True]], "headers": ["Name", "Age", "Member"], + "metadata": None, } dataframe_input = gr.Dataframe(headers=["Name", "Age", "Member"]) output = dataframe_input.preprocess(x_data) @@ -1218,7 +1219,11 @@ def test_component_functions(self): headers=["Name", "Age", "Member"], label="Dataframe Input" ) assert dataframe_input.get_config() == { - "value": {"headers": ["Name", "Age", "Member"], "data": [["", "", ""]]}, + "value": { + "headers": ["Name", "Age", "Member"], + "data": [["", "", ""]], + "metadata": None, + }, "selectable": False, "headers": ["Name", "Age", "Member"], "row_count": (1, "dynamic"), @@ -1250,7 +1255,7 @@ def test_component_functions(self): dataframe_output = gr.Dataframe() assert dataframe_output.get_config() == { - "value": {"headers": [1, 2, 3], "data": [["", "", ""]]}, + "value": {"headers": [1, 2, 3], "data": [["", "", ""]], "metadata": None}, "selectable": False, "headers": [1, 2, 3], "row_count": (1, "dynamic"), @@ -1281,17 +1286,18 @@ def test_postprocess(self): """ dataframe_output = gr.Dataframe() output = dataframe_output.postprocess([]) - assert output == {"data": [[]], "headers": []} + assert output == {"data": [[]], "headers": [], "metadata": None} output = dataframe_output.postprocess(np.zeros((2, 2))) - assert output == {"data": [[0, 0], [0, 0]], "headers": [1, 2]} + assert output == {"data": [[0, 0], [0, 0]], "headers": [1, 2], "metadata": None} output = dataframe_output.postprocess([[1, 3, 5]]) - assert output == {"data": [[1, 3, 5]], "headers": [1, 2, 3]} + assert output == {"data": [[1, 3, 5]], "headers": [1, 2, 3], "metadata": None} output = dataframe_output.postprocess( pd.DataFrame([[2, True], [3, True], [4, False]], columns=["num", "prime"]) ) assert output == { "headers": ["num", "prime"], "data": [[2, True], [3, True], [4, False]], + "metadata": None, } with pytest.raises(ValueError): gr.Dataframe(type="unknown") @@ -1302,12 +1308,14 @@ def test_postprocess(self): assert output == { "headers": ["one", "two"], "data": [[2, True], [3, True]], + "metadata": None, } dataframe_output = gr.Dataframe(headers=["one", "two", "three"]) output = dataframe_output.postprocess([[2, True, "ab", 4], [3, True, "cd", 5]]) assert output == { "headers": ["one", "two", "three", 4], "data": [[2, True, "ab", 4], [3, True, "cd", 5]], + "metadata": None, } def test_dataframe_postprocess_all_types(self): @@ -1347,6 +1355,7 @@ def test_dataframe_postprocess_all_types(self): "# Goodbye", ], ], + "metadata": None, } def test_dataframe_postprocess_only_dates(self): @@ -1370,6 +1379,44 @@ def test_dataframe_postprocess_only_dates(self): pd.Timestamp("2022-02-16 00:00:00"), ], ], + "metadata": None, + } + + def test_dataframe_postprocess_styler(self): + component = gr.Dataframe() + df = pd.DataFrame( + { + "name": ["Adam", "Mike"] * 4, + "gpa": [1.1, 1.12] * 4, + "sat": [800, 800] * 4, + } + ) + s = df.style.format(precision=1, decimal=",") + output = component.postprocess(s) + assert output == { + "data": [ + ["Adam", 1.1, 800], + ["Mike", 1.12, 800], + ["Adam", 1.1, 800], + ["Mike", 1.12, 800], + ["Adam", 1.1, 800], + ["Mike", 1.12, 800], + ["Adam", 1.1, 800], + ["Mike", 1.12, 800], + ], + "headers": ["name", "gpa", "sat"], + "metadata": { + "display_value": [ + ["Adam", "1,1", "800"], + ["Mike", "1,1", "800"], + ["Adam", "1,1", "800"], + ["Mike", "1,1", "800"], + ["Adam", "1,1", "800"], + ["Mike", "1,1", "800"], + ["Adam", "1,1", "800"], + ["Mike", "1,1", "800"], + ] + }, }