diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 2da67ab..cad0dbf 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -40,7 +40,7 @@ jobs: - name: Install dependencies and run tests run: | uv pip install -e ".[all]" --system - uv pip install pytest pytest-cov pytest-xdist --system + uv pip install -r requirements/test.txt --system make test-cov - name: Install and run mypy run: | diff --git a/requirements/test.txt b/requirements/test.txt new file mode 100644 index 0000000..8b5374e --- /dev/null +++ b/requirements/test.txt @@ -0,0 +1,7 @@ +anyio +pytest +pytest-asyncio +pytest-cov +pytest-tornasync +pytest-trio +pytest-xdist \ No newline at end of file diff --git a/sksmithy/_static/tui.tcss b/sksmithy/_static/tui.tcss index 5ed90ce..4e38a73 100644 --- a/sksmithy/_static/tui.tcss +++ b/sksmithy/_static/tui.tcss @@ -12,6 +12,7 @@ Screen { align: center middle; + min-width: 100vw; } Header { diff --git a/sksmithy/tui/_components.py b/sksmithy/tui/_components.py index e93ae99..93f497a 100644 --- a/sksmithy/tui/_components.py +++ b/sksmithy/tui/_components.py @@ -55,8 +55,9 @@ def on_input_change(self: Self, event: Input.Changed) -> None: timeout=5, ) else: - output_file = self.app.query_one("#output_file", Input) - output_file.value = f"{event.value.lower()}.py" + output_file = self.app.query_one("#output-file", Input) + if not output_file.value: + output_file.value = f"{event.value.lower()}.py" class Estimator(Container): @@ -210,8 +211,8 @@ class DestinationFile(Container): """Destination file input component.""" def compose(self: Self) -> ComposeResult: - yield Prompt(PROMPT_OUTPUT, classes="label", id="output_prompt") - yield Input(placeholder="mightyestimator.py", id="output_file") + yield Prompt(PROMPT_OUTPUT, classes="label") + yield Input(placeholder="mightyestimator.py", id="output-file") class ForgeButton(Container): @@ -220,10 +221,10 @@ class ForgeButton(Container): def compose(self: Self) -> ComposeResult: yield Button.success( label="Forge ⚒️", - id="forge_btn", + id="forge-btn", ) - @on(Button.Pressed, "#forge_btn") + @on(Button.Pressed, "#forge-btn") def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901 errors = [] @@ -237,7 +238,7 @@ def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901 predict_proba = self.app.query_one("#predict_proba", Switch).value decision_function = self.app.query_one("#decision_function", Switch).value - output_file = self.app.query_one("#output_file", Input).value + output_file = self.app.query_one("#output-file", Input).value match name_parser(name_input): case Ok(name): @@ -249,7 +250,7 @@ def on_forge(self: Self, _: Button.Pressed) -> None: # noqa: C901 case str(v): estimator_type = EstimatorType(v) case Select.BLANK: - errors.append("Estimator cannot be None!") + errors.append("Estimator cannot be empty!") match params_parser(required_params): case Ok(required): @@ -311,10 +312,6 @@ class ForgeRow(Grid): """Row grid for forge.""" - -forge_row = ForgeRow(Static(), Static(), ForgeButton(), DestinationFile(), Static(), Static(), id="forge_row") - - class Title(Static): pass diff --git a/sksmithy/tui/_tui.py b/sksmithy/tui/_tui.py index 6d6fb03..ae47015 100644 --- a/sksmithy/tui/_tui.py +++ b/sksmithy/tui/_tui.py @@ -5,11 +5,14 @@ from textual.app import App, ComposeResult from textual.containers import Container, Horizontal, ScrollableContainer from textual.reactive import reactive -from textual.widgets import Button, Footer, Header, Rule +from textual.widgets import Button, Footer, Header, Rule, Static from sksmithy.tui._components import ( DecisionFunction, + DestinationFile, Estimator, + ForgeButton, + ForgeRow, Linear, Name, Optional, @@ -17,7 +20,6 @@ Required, SampleWeight, Sidebar, - forge_row, ) if sys.version_info >= (3, 11): # pragma: no cover @@ -41,6 +43,13 @@ class ForgeTUI(App): show_sidebar = reactive(False) # noqa: FBT003 + def on_mount(self: Self) -> None: + """Compose on mount. + + Q: is this needed? + """ + self.compose() + def compose(self: Self) -> ComposeResult: """Create child widgets for the app.""" yield Container( @@ -51,18 +60,25 @@ def compose(self: Self) -> ComposeResult: Horizontal(SampleWeight(), Linear()), Horizontal(PredictProba(), DecisionFunction()), Rule(), - forge_row, + ForgeRow( + Static(), + Static(), + ForgeButton(), + DestinationFile(), + Static(), + Static(), + ), Rule(), ), Sidebar(classes="-hidden"), Footer(), ) - def action_toggle_dark(self: Self) -> None: + def action_toggle_dark(self: Self) -> None: # pragma: no cover """Toggle dark mode.""" self.dark = not self.dark - def action_toggle_sidebar(self: Self) -> None: + def action_toggle_sidebar(self: Self) -> None: # pragma: no cover """Toggle sidebar component.""" sidebar = self.query_one(Sidebar) self.set_focus(None) @@ -76,10 +92,10 @@ def action_toggle_sidebar(self: Self) -> None: def action_forge(self: Self) -> None: """Press forge button.""" - forge_btn = self.query_one("#forge_btn", Button) + forge_btn = self.query_one("#forge-btn", Button) forge_btn.press() -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover tui = ForgeTUI() tui.run() diff --git a/sksmithy/tui/_validators.py b/sksmithy/tui/_validators.py index 8638b7f..8dc1eaa 100644 --- a/sksmithy/tui/_validators.py +++ b/sksmithy/tui/_validators.py @@ -17,7 +17,7 @@ class _BaseValidator(Validator): @staticmethod - def parser(value: str) -> Result[str | list[str], str]: + def parser(value: str) -> Result[str | list[str], str]: # pragma: no cover raise NotImplementedError def validate(self: Self, value: str) -> ValidationResult: diff --git a/tests/conftest.py b/tests/conftest.py index 4872254..e231dd9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,26 @@ def required(request: pytest.FixtureRequest) -> list[str]: return request.param +@pytest.fixture( + params=[ + ("a,a", "Found repeated parameters!"), + ("a-a", "The following parameters are invalid python identifiers: ('a-a',)"), + ] +) +def invalid_required(request: pytest.FixtureRequest) -> tuple[str, str]: + return request.param + + +@pytest.fixture( + params=[ + ("b,b", "Found repeated parameters!"), + ("b b", "The following parameters are invalid python identifiers: ('b b',)"), + ] +) +def invalid_optional(request: pytest.FixtureRequest) -> tuple[str, str]: + return request.param + + @pytest.fixture(params=[["mu", "sigma"], []]) def optional(request: pytest.FixtureRequest) -> list[str]: return request.param diff --git a/tests/test_tui.py b/tests/test_tui.py new file mode 100644 index 0000000..ea9c7d1 --- /dev/null +++ b/tests/test_tui.py @@ -0,0 +1,178 @@ +from pathlib import Path + +import pytest +from textual.widgets import Button, Input, Select, Switch + +from sksmithy._models import EstimatorType +from sksmithy.tui import ForgeTUI + + +async def test_smoke() -> None: + """Basic smoke test.""" + app = ForgeTUI() + async with app.run_test(size=None) as pilot: + await pilot.pause() + assert pilot is not None + + await pilot.pause() + await pilot.exit(0) + + +@pytest.mark.parametrize( + ("name_", "err_msg"), + [ + ("MightyEstimator", ""), + ("not-valid-name", "`not-valid-name` is not a valid python class name!"), + ("class", "`class` is a python reserved keyword!"), + ], +) +async def test_name(name_: str, err_msg: str) -> None: + """Test `name` text_input component.""" + app = ForgeTUI() + async with app.run_test(size=None) as pilot: + name_comp = pilot.app.query_one("#name", Input) + name_comp.value = name_ + await pilot.pause() + + assert (not name_comp.is_valid) == bool(err_msg) + + notifications = list(pilot.app._notifications) # noqa: SLF001 + assert len(notifications) == int(bool(err_msg)) + + if notifications: + assert notifications[0].message == err_msg + + +async def test_estimator_interaction(estimator: EstimatorType) -> None: + """Test that all toggle components interact correctly with the selected estimator.""" + app = ForgeTUI() + async with app.run_test(size=None) as pilot: + pilot.app.query_one("#estimator", Select).value = estimator.value + await pilot.pause() + + assert (not pilot.app.query_one("#linear", Switch).disabled) == ( + estimator in {EstimatorType.ClassifierMixin, EstimatorType.RegressorMixin} + ) + assert (not pilot.app.query_one("#predict_proba", Switch).disabled) == ( + estimator in {EstimatorType.ClassifierMixin, EstimatorType.OutlierMixin} + ) + + assert (not pilot.app.query_one("#decision_function", Switch).disabled) == ( + estimator == EstimatorType.ClassifierMixin + ) + + if estimator == EstimatorType.ClassifierMixin: + linear = pilot.app.query_one("#linear", Switch) + linear.value = True + + await pilot.pause() + assert pilot.app.query_one("#decision_function", Switch).disabled + + +async def test_valid_params() -> None: + """Test required and optional params interaction.""" + app = ForgeTUI() + required_ = "a,b" + optional_ = "c,d" + async with app.run_test(size=None) as pilot: + required_comp = pilot.app.query_one("#required", Input) + optional_comp = pilot.app.query_one("#optional", Input) + + required_comp.value = required_ + optional_comp.value = optional_ + + await required_comp.action_submit() + await optional_comp.action_submit() + await pilot.pause(0.01) + + notifications = list(pilot.app._notifications) # noqa: SLF001 + assert not notifications + + +@pytest.mark.parametrize(("required_", "optional_"), [("a,b", "a"), ("a", "a,b")]) +async def test_duplicated_params(required_: str, optional_: str) -> None: + app = ForgeTUI() + msg = "The following parameters are duplicated between required and optional: {'a'}" + + async with app.run_test(size=None) as pilot: + required_comp = pilot.app.query_one("#required", Input) + optional_comp = pilot.app.query_one("#optional", Input) + + required_comp.value = required_ + optional_comp.value = optional_ + + await required_comp.action_submit() + await optional_comp.action_submit() + await pilot.pause() + + forge_btn = pilot.app.query_one("#forge-btn", Button) + forge_btn.action_press() + await pilot.pause() + + assert all(msg in n.message for n in pilot.app._notifications) # noqa: SLF001 + + +async def test_forge_raise() -> None: + """Test forge button and all of its interactions.""" + app = ForgeTUI() + async with app.run_test(size=None) as pilot: + required_comp = pilot.app.query_one("#required", Input) + optional_comp = pilot.app.query_one("#optional", Input) + + required_comp.value = "a,a" + optional_comp.value = "b b" + + await required_comp.action_submit() + await optional_comp.action_submit() + await pilot.pause() + + forge_btn = pilot.app.query_one("#forge-btn", Button) + forge_btn.action_press() + await pilot.pause() + + m1, m2, m3 = (n.message for n in pilot.app._notifications) # noqa: SLF001 + + assert "Found repeated parameters!" in m1 + assert "The following parameters are invalid python identifiers: ('b b',)" in m2 + + assert "Name cannot be empty!" in m3 + assert "Estimator cannot be empty!" in m3 + assert "Outfile file cannot be empty!" in m3 + assert "Found repeated parameters!" in m3 + assert "The following parameters are invalid python identifiers: ('b b',)" in m3 + + +@pytest.mark.parametrize("use_binding", [True, False]) +async def test_forge(tmp_path: Path, use_binding: bool) -> None: + """Test forge button and all of its interactions.""" + app = ForgeTUI() + name = "MightyEstimator" + estimator = "classifier" + async with app.run_test(size=None) as pilot: + name_comp = pilot.app.query_one("#name", Input) + estimator_comp = pilot.app.query_one("#estimator", Select) + await pilot.pause() + + output_file_comp = pilot.app.query_one("#output-file", Input) + + name_comp.value = name + estimator_comp.value = estimator + + await pilot.pause() + + output_file = tmp_path / (f"{name.lower()}.py") + output_file_comp.value = str(output_file) + await output_file_comp.action_submit() + await pilot.pause() + + if use_binding: + await pilot.press("F") + else: + forge_btn = pilot.app.query_one("#forge-btn", Button) + forge_btn.action_press() + await pilot.pause() + + notification = next(iter(pilot.app._notifications)) # noqa: SLF001 + + assert f"Template forged at {output_file!s}" in notification.message + assert output_file.exists()