Skip to content

Commit

Permalink
update compatibility with sqlalchemy 2 (#1123)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidism authored Oct 14, 2022
2 parents 88cec96 + a3dcc0d commit 7f94bca
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Version 3.0.2

Unreleased

- Update compatibility with SQLAlchemy 2. :issue:`1122`


Version 3.0.1
-------------
Expand Down
5 changes: 4 additions & 1 deletion src/flask_sqlalchemy/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,8 +978,11 @@ def __getattr__(self, name: str) -> t.Any:
if name == "event":
return sa.event

if name.startswith("_"):
raise AttributeError(name)

for mod in (sa, sa.orm):
if name in mod.__all__:
if hasattr(mod, name):
return getattr(mod, name)

raise AttributeError(name)
2 changes: 1 addition & 1 deletion tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_sqlite_relative_path(app: Flask) -> None:
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///test.db"
db = SQLAlchemy(app)
db.create_all()
assert isinstance(db.engine.pool, sa.pool.NullPool)
assert not isinstance(db.engine.pool, sa.pool.StaticPool)
db_path = db.engine.url.database
assert db_path.startswith(app.instance_path) # type: ignore[union-attr]
assert os.path.exists(db_path) # type: ignore[arg-type]
Expand Down
13 changes: 13 additions & 0 deletions tests/test_legacy_query.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
from __future__ import annotations

import typing as t
import warnings

import pytest
import sqlalchemy as sa
import sqlalchemy.exc
from flask import Flask
from werkzeug.exceptions import NotFound

from flask_sqlalchemy import SQLAlchemy
from flask_sqlalchemy.query import Query


@pytest.fixture(autouse=True)
def ignore_query_warning() -> t.Generator[None, None, None]:
if hasattr(sa.exc, "LegacyAPIWarning"):
with warnings.catch_warnings():
exc = sa.exc.LegacyAPIWarning # type: ignore[attr-defined]
warnings.simplefilter("ignore", exc)
yield
else:
yield


@pytest.mark.usefixtures("app_ctx")
def test_get_or_404(db: SQLAlchemy, Todo: t.Any) -> None:
item = Todo()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class Duck(db.Model):

class IdMixin:
@sa.orm.declared_attr
def id(cls) -> sa.Column[sa.Integer]: # noqa: B902
def id(cls): # type: ignore[no-untyped-def] # noqa: B902
return sa.Column(sa.Integer, sa.ForeignKey(Duck.id), primary_key=True)

class RubberDuck(IdMixin, Duck): # type: ignore[misc]
Expand Down

0 comments on commit 7f94bca

Please sign in to comment.