Skip to content

Commit

Permalink
feat: TUI tests (#20)
Browse files Browse the repository at this point in the history
* init tests

* fix DOM initial rendering

* ok coverage
  • Loading branch information
FBruzzesi authored Jun 10, 2024
1 parent a4d04f0 commit 8c79747
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
- name: Install dependencies and run tests
run: |
uv pip install -e ".[all]" --system
uv pip install pytest pytest-cov pytest-xdist --system
uv pip install -r requirements/test.txt --system
make test-cov
- name: Install and run mypy
run: |
Expand Down
7 changes: 7 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
anyio
pytest
pytest-asyncio
pytest-cov
pytest-tornasync
pytest-trio
pytest-xdist
1 change: 1 addition & 0 deletions sksmithy/_static/tui.tcss
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

Screen {
align: center middle;
min-width: 100vw;
}

Header {
Expand Down
21 changes: 9 additions & 12 deletions sksmithy/tui/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def on_input_change(self: Self, event: Input.Changed) -> None:
timeout=5,
)
else:
output_file = self.app.query_one("#output_file", Input)
output_file.value = f"{event.value.lower()}.py"
output_file = self.app.query_one("#output-file", Input)
if not output_file.value:
output_file.value = f"{event.value.lower()}.py"


class Estimator(Container):
Expand Down Expand Up @@ -210,8 +211,8 @@ class DestinationFile(Container):
"""Destination file input component."""

def compose(self: Self) -> ComposeResult:
yield Prompt(PROMPT_OUTPUT, classes="label", id="output_prompt")
yield Input(placeholder="mightyestimator.py", id="output_file")
yield Prompt(PROMPT_OUTPUT, classes="label")
yield Input(placeholder="mightyestimator.py", id="output-file")


class ForgeButton(Container):
Expand All @@ -220,10 +221,10 @@ class ForgeButton(Container):
def compose(self: Self) -> ComposeResult:
yield Button.success(
label="Forge ⚒️",
id="forge_btn",
id="forge-btn",
)

@on(Button.Pressed, "#forge_btn")
@on(Button.Pressed, "#forge-btn")
def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901
errors = []

Expand All @@ -237,7 +238,7 @@ def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901
predict_proba = self.app.query_one("#predict_proba", Switch).value
decision_function = self.app.query_one("#decision_function", Switch).value

output_file = self.app.query_one("#output_file", Input).value
output_file = self.app.query_one("#output-file", Input).value

match name_parser(name_input):
case Ok(name):
Expand All @@ -249,7 +250,7 @@ def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901
case str(v):
estimator_type = EstimatorType(v)
case Select.BLANK:
errors.append("Estimator cannot be None!")
errors.append("Estimator cannot be empty!")

match params_parser(required_params):
case Ok(required):
Expand Down Expand Up @@ -311,10 +312,6 @@ class ForgeRow(Grid):
"""Row grid for forge."""



forge_row = ForgeRow(Static(), Static(), ForgeButton(), DestinationFile(), Static(), Static(), id="forge_row")


class Title(Static):
pass

Expand Down
30 changes: 23 additions & 7 deletions sksmithy/tui/_tui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
from textual.app import App, ComposeResult
from textual.containers import Container, Horizontal, ScrollableContainer
from textual.reactive import reactive
from textual.widgets import Button, Footer, Header, Rule
from textual.widgets import Button, Footer, Header, Rule, Static

from sksmithy.tui._components import (
DecisionFunction,
DestinationFile,
Estimator,
ForgeButton,
ForgeRow,
Linear,
Name,
Optional,
PredictProba,
Required,
SampleWeight,
Sidebar,
forge_row,
)

if sys.version_info >= (3, 11): # pragma: no cover
Expand All @@ -41,6 +43,13 @@ class ForgeTUI(App):

show_sidebar = reactive(False) # noqa: FBT003

def on_mount(self: Self) -> None:
"""Compose on mount.
Q: is this needed?
"""
self.compose()

def compose(self: Self) -> ComposeResult:
"""Create child widgets for the app."""
yield Container(
Expand All @@ -51,18 +60,25 @@ def compose(self: Self) -> ComposeResult:
Horizontal(SampleWeight(), Linear()),
Horizontal(PredictProba(), DecisionFunction()),
Rule(),
forge_row,
ForgeRow(
Static(),
Static(),
ForgeButton(),
DestinationFile(),
Static(),
Static(),
),
Rule(),
),
Sidebar(classes="-hidden"),
Footer(),
)

def action_toggle_dark(self: Self) -> None:
def action_toggle_dark(self: Self) -> None: # pragma: no cover
"""Toggle dark mode."""
self.dark = not self.dark

def action_toggle_sidebar(self: Self) -> None:
def action_toggle_sidebar(self: Self) -> None: # pragma: no cover
"""Toggle sidebar component."""
sidebar = self.query_one(Sidebar)
self.set_focus(None)
Expand All @@ -76,10 +92,10 @@ def action_toggle_sidebar(self: Self) -> None:

def action_forge(self: Self) -> None:
"""Press forge button."""
forge_btn = self.query_one("#forge_btn", Button)
forge_btn = self.query_one("#forge-btn", Button)
forge_btn.press()


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
tui = ForgeTUI()
tui.run()
2 changes: 1 addition & 1 deletion sksmithy/tui/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class _BaseValidator(Validator):
@staticmethod
def parser(value: str) -> Result[str | list[str], str]:
def parser(value: str) -> Result[str | list[str], str]: # pragma: no cover
raise NotImplementedError

def validate(self: Self, value: str) -> ValidationResult:
Expand Down
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,26 @@ def required(request: pytest.FixtureRequest) -> list[str]:
return request.param


@pytest.fixture(
params=[
("a,a", "Found repeated parameters!"),
("a-a", "The following parameters are invalid python identifiers: ('a-a',)"),
]
)
def invalid_required(request: pytest.FixtureRequest) -> tuple[str, str]:
return request.param


@pytest.fixture(
params=[
("b,b", "Found repeated parameters!"),
("b b", "The following parameters are invalid python identifiers: ('b b',)"),
]
)
def invalid_optional(request: pytest.FixtureRequest) -> tuple[str, str]:
return request.param


@pytest.fixture(params=[["mu", "sigma"], []])
def optional(request: pytest.FixtureRequest) -> list[str]:
return request.param
Expand Down
178 changes: 178 additions & 0 deletions tests/test_tui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from pathlib import Path

import pytest
from textual.widgets import Button, Input, Select, Switch

from sksmithy._models import EstimatorType
from sksmithy.tui import ForgeTUI


async def test_smoke() -> None:
"""Basic smoke test."""
app = ForgeTUI()
async with app.run_test(size=None) as pilot:
await pilot.pause()
assert pilot is not None

await pilot.pause()
await pilot.exit(0)


@pytest.mark.parametrize(
("name_", "err_msg"),
[
("MightyEstimator", ""),
("not-valid-name", "`not-valid-name` is not a valid python class name!"),
("class", "`class` is a python reserved keyword!"),
],
)
async def test_name(name_: str, err_msg: str) -> None:
"""Test `name` text_input component."""
app = ForgeTUI()
async with app.run_test(size=None) as pilot:
name_comp = pilot.app.query_one("#name", Input)
name_comp.value = name_
await pilot.pause()

assert (not name_comp.is_valid) == bool(err_msg)

notifications = list(pilot.app._notifications) # noqa: SLF001
assert len(notifications) == int(bool(err_msg))

if notifications:
assert notifications[0].message == err_msg


async def test_estimator_interaction(estimator: EstimatorType) -> None:
"""Test that all toggle components interact correctly with the selected estimator."""
app = ForgeTUI()
async with app.run_test(size=None) as pilot:
pilot.app.query_one("#estimator", Select).value = estimator.value
await pilot.pause()

assert (not pilot.app.query_one("#linear", Switch).disabled) == (
estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin}
)
assert (not pilot.app.query_one("#predict_proba", Switch).disabled) == (
estimator in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin}
)

assert (not pilot.app.query_one("#decision_function", Switch).disabled) == (
estimator == EstimatorType.ClassifierMixin
)

if estimator == EstimatorType.ClassifierMixin:
linear = pilot.app.query_one("#linear", Switch)
linear.value = True

await pilot.pause()
assert pilot.app.query_one("#decision_function", Switch).disabled


async def test_valid_params() -> None:
"""Test required and optional params interaction."""
app = ForgeTUI()
required_ = "a,b"
optional_ = "c,d"
async with app.run_test(size=None) as pilot:
required_comp = pilot.app.query_one("#required", Input)
optional_comp = pilot.app.query_one("#optional", Input)

required_comp.value = required_
optional_comp.value = optional_

await required_comp.action_submit()
await optional_comp.action_submit()
await pilot.pause(0.01)

notifications = list(pilot.app._notifications) # noqa: SLF001
assert not notifications


@pytest.mark.parametrize(("required_", "optional_"), [("a,b", "a"), ("a", "a,b")])
async def test_duplicated_params(required_: str, optional_: str) -> None:
app = ForgeTUI()
msg = "The following parameters are duplicated between required and optional: {'a'}"

async with app.run_test(size=None) as pilot:
required_comp = pilot.app.query_one("#required", Input)
optional_comp = pilot.app.query_one("#optional", Input)

required_comp.value = required_
optional_comp.value = optional_

await required_comp.action_submit()
await optional_comp.action_submit()
await pilot.pause()

forge_btn = pilot.app.query_one("#forge-btn", Button)
forge_btn.action_press()
await pilot.pause()

assert all(msg in n.message for n in pilot.app._notifications) # noqa: SLF001


async def test_forge_raise() -> None:
"""Test forge button and all of its interactions."""
app = ForgeTUI()
async with app.run_test(size=None) as pilot:
required_comp = pilot.app.query_one("#required", Input)
optional_comp = pilot.app.query_one("#optional", Input)

required_comp.value = "a,a"
optional_comp.value = "b b"

await required_comp.action_submit()
await optional_comp.action_submit()
await pilot.pause()

forge_btn = pilot.app.query_one("#forge-btn", Button)
forge_btn.action_press()
await pilot.pause()

m1, m2, m3 = (n.message for n in pilot.app._notifications) # noqa: SLF001

assert "Found repeated parameters!" in m1
assert "The following parameters are invalid python identifiers: ('b b',)" in m2

assert "Name cannot be empty!" in m3
assert "Estimator cannot be empty!" in m3
assert "Outfile file cannot be empty!" in m3
assert "Found repeated parameters!" in m3
assert "The following parameters are invalid python identifiers: ('b b',)" in m3


@pytest.mark.parametrize("use_binding", [True, False])
async def test_forge(tmp_path: Path, use_binding: bool) -> None:
"""Test forge button and all of its interactions."""
app = ForgeTUI()
name = "MightyEstimator"
estimator = "classifier"
async with app.run_test(size=None) as pilot:
name_comp = pilot.app.query_one("#name", Input)
estimator_comp = pilot.app.query_one("#estimator", Select)
await pilot.pause()

output_file_comp = pilot.app.query_one("#output-file", Input)

name_comp.value = name
estimator_comp.value = estimator

await pilot.pause()

output_file = tmp_path / (f"{name.lower()}.py")
output_file_comp.value = str(output_file)
await output_file_comp.action_submit()
await pilot.pause()

if use_binding:
await pilot.press("F")
else:
forge_btn = pilot.app.query_one("#forge-btn", Button)
forge_btn.action_press()
await pilot.pause()

notification = next(iter(pilot.app._notifications)) # noqa: SLF001

assert f"Template forged at {output_file!s}" in notification.message
assert output_file.exists()

0 comments on commit 8c79747

Please sign in to comment.