Skip to content

Commit

Permalink
Merge pull request #14 from Ensembl/jalvarez/update_dbc
Browse files Browse the repository at this point in the history
Update database connection and plan next release
  • Loading branch information
JAlvarezJarreta authored Jul 11, 2024
2 parents f4a09eb + 6627f2c commit dbf2f2c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/ensembl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
"""Ensembl Python general-purpose utils library."""

__version__ = "0.3.0"
__version__ = "0.4.0"

__all__ = [
"StrPath",
Expand Down
21 changes: 1 addition & 20 deletions src/ensembl/utils/database/dbconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from typing import ContextManager, Generator, Optional, TypeVar

import sqlalchemy
from sqlalchemy import create_engine, event, text
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy.schema import MetaData, Table

Expand Down Expand Up @@ -152,25 +152,6 @@ def dispose(self) -> None:
"""Disposes of the connection pool."""
self._engine.dispose()

def execute(self, statement: Query, parameters=None, execution_options=None) -> sqlalchemy.engine.Result:
"""Executes the given SQL query and returns its result.
See `sqlalchemy.engine.Connection.execute()` method for more information about the
additional arguments.
Args:
statement: SQL query to execute.
parameters: Parameters which will be bound into the statement.
execution_options: Optional dictionary of execution options, which will be associated
with the statement execution.
"""
if isinstance(statement, str):
statement = text(statement) # type: ignore[assignment]
return self.connect().execute(
statement=statement, parameters=parameters, execution_options=execution_options
) # type: ignore[call-overload]

def _enable_sqlite_savepoints(self, engine: sqlalchemy.engine.Engine) -> None:
"""Enables SQLite SAVEPOINTS to allow session rollbacks."""

Expand Down
39 changes: 5 additions & 34 deletions tests/database/test_dbconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@
# limitations under the License.
"""Unit testing of `ensembl.utils.database.dbconnection` module."""

from contextlib import nullcontext as does_not_raise
import os
from pathlib import Path
from typing import ContextManager

import pytest
from pytest import FixtureRequest, param, raises
from pytest import FixtureRequest, param
from sqlalchemy import text, VARCHAR
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.automap import automap_base
from sqlalchemy_utils import create_database, database_exists, drop_database

from ensembl.utils.database import DBConnection, Query, UnitTestDB
from ensembl.utils.database import DBConnection, UnitTestDB


class MockBase(DeclarativeBase):
Expand Down Expand Up @@ -157,34 +155,7 @@ def test_dispose(self) -> None:
num_conn = self.dbc._engine.pool.checkedin() # pylint: disable=protected-access
assert num_conn == 0, "A new pool should have 0 checked-in connections"

@pytest.mark.dependency(name="test_exec", depends=["test_init"], scope="class")
@pytest.mark.parametrize(
"query, nrows, expectation",
[
param("SELECT * FROM gibberish", 6, does_not_raise(), id="Valid string query"),
param(text("SELECT * FROM gibberish"), 6, does_not_raise(), id="Valid text query"),
param(
"SELECT * FROM my_table",
0,
raises(SQLAlchemyError, match=r"(my_table.* doesn't exist|no such table: my_table)"),
id="Querying an unexistent table",
),
],
)
def test_execute(self, query: Query, nrows: int, expectation: ContextManager) -> None:
"""Tests `DBConnection.execute()` method.
Args:
query: SQL query.
nrows: Number of rows expected to be returned from the query.
expectation: Context manager for the expected exception.
"""
with expectation:
result = self.dbc.execute(query)
assert len(result.fetchall()) == nrows, "Unexpected number of rows returned"

@pytest.mark.dependency(depends=["test_init", "test_connect", "test_exec"], scope="class")
@pytest.mark.dependency(depends=["test_init", "test_connect"], scope="class")
@pytest.mark.parametrize(
"identifier, rows_to_add, before, after",
[
Expand Down Expand Up @@ -232,7 +203,7 @@ def test_session_scope(
results = session.execute(query)
assert len(results.fetchall()) == after

@pytest.mark.dependency(depends=["test_init", "test_connect", "test_exec"], scope="class")
@pytest.mark.dependency(depends=["test_init", "test_connect"], scope="class")
def test_test_session_scope(self) -> None:
"""Tests `DBConnection.test_session_scope()` method."""
# Session requires mapped classes to interact with the database
Expand Down

0 comments on commit dbf2f2c

Please sign in to comment.