Skip to content

Commit de34b87

Browse files
committed
Migrate app code to SQLAlchemy 2.0 Query API
Replace all legacy session.query() and Model.query patterns with session.execute(select()), session.scalars(), session.scalar(), and db.paginate() across 10 files: flowapp/__init__.py, auth.py, messages.py flowapp/models/utils.py, rules/whitelist.py, community.py, user.py flowapp/services/base.py, rule_service.py, whitelist_service.py Add test coverage for previously untested DB-touching code paths: tests/test_auth.py, test_model_utils.py, test_services_base.py (new files), plus additions to test_flowapp.py, test_messages.py, test_models.py, test_whitelist_service.py
1 parent d1f8852 commit de34b87

17 files changed

Lines changed: 879 additions & 106 deletions

flowapp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33

44
from flask import Flask, redirect, render_template, session, url_for, flash
5+
from sqlalchemy import select
56

67
from flask_sso import SSO
78
from flask_sqlalchemy import SQLAlchemy
@@ -128,7 +129,7 @@ def index():
128129
@auth_required
129130
def select_org(org_id=None):
130131
uuid = session.get("user_uuid")
131-
user = db.session.query(models.User).filter_by(uuid=uuid).first()
132+
user = db.session.execute(select(models.User).filter_by(uuid=uuid)).scalar_one_or_none()
132133

133134
if user is None:
134135
return render_template("errors/404.html"), 404

flowapp/auth.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from functools import wraps
22
from typing import List, Optional
33
from flask import current_app, redirect, request, url_for, session, abort
4+
from sqlalchemy import select
45

56
from flowapp import __version__, db, validators
67
from flowapp.models import Flowspec4, Flowspec6, RTBH, Whitelist, get_user_nets
@@ -153,35 +154,35 @@ def get_user_allowed_rule_ids(rule_type: str, user_id: int, user_role_ids: List[
153154
# Admin users can modify any rules
154155
if 3 in user_role_ids:
155156
if rule_type == "ipv4":
156-
return [r.id for r in db.session.query(Flowspec4.id).all()]
157+
return list(db.session.scalars(select(Flowspec4.id)))
157158
elif rule_type == "ipv6":
158-
return [r.id for r in db.session.query(Flowspec6.id).all()]
159+
return list(db.session.scalars(select(Flowspec6.id)))
159160
elif rule_type == "rtbh":
160-
return [r.id for r in db.session.query(RTBH.id).all()]
161+
return list(db.session.scalars(select(RTBH.id)))
161162
elif rule_type == "whitelist":
162-
return [r.id for r in db.session.query(Whitelist.id).all()]
163+
return list(db.session.scalars(select(Whitelist.id)))
163164
return []
164165

165166
# Regular users - filter by network ranges
166167
net_ranges = get_user_nets(user_id)
167168

168169
if rule_type == "ipv4":
169-
rules = db.session.query(Flowspec4).all()
170+
rules = db.session.scalars(select(Flowspec4)).all()
170171
filtered_rules = validators.filter_rules_in_network(net_ranges, rules)
171172
return [r.id for r in filtered_rules]
172173

173174
elif rule_type == "ipv6":
174-
rules = db.session.query(Flowspec6).all()
175+
rules = db.session.scalars(select(Flowspec6)).all()
175176
filtered_rules = validators.filter_rules_in_network(net_ranges, rules)
176177
return [r.id for r in filtered_rules]
177178

178179
elif rule_type == "rtbh":
179-
rules = db.session.query(RTBH).all()
180+
rules = db.session.scalars(select(RTBH)).all()
180181
filtered_rules = validators.filter_rtbh_rules(net_ranges, rules)
181182
return [r.id for r in filtered_rules]
182183

183184
elif rule_type == "whitelist":
184-
rules = db.session.query(Whitelist).all()
185+
rules = db.session.scalars(select(Whitelist)).all()
185186
filtered_rules = validators.filter_rules_in_network(net_ranges, rules)
186187
return [r.id for r in filtered_rules]
187188

flowapp/messages.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from flowapp.flowspec import translate_sequence as trps
1212
from flask import current_app
13+
from sqlalchemy import select
1314
from flowapp.models import ASPath
1415
from flowapp import db
1516

@@ -138,7 +139,7 @@ def create_rtbh(rule, message_type=ANNOUNCE):
138139

139140
as_path_string = ""
140141
if rule.community.as_path:
141-
match = db.session.query(ASPath).filter(ASPath.prefix == source).first()
142+
match = db.session.execute(select(ASPath).filter_by(prefix=source)).scalar_one_or_none()
142143
as_path_string = f"as-path [ {match.as_path} ]" if match else ""
143144

144145
return "{neighbor}{action} route {source} next-hop {nexthop} {as_path} {community} {large_community} {extended_community}{rd_string}".format(

flowapp/models/community.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy import event
1+
from sqlalchemy import event, select
22
from .base import db
33

44

@@ -26,7 +26,7 @@ def __init__(self, name, comm, larcomm, extcomm, description, as_path, role_id):
2626

2727
@classmethod
2828
def get_whitelistable_communities(cls, id_list):
29-
return cls.query.filter(cls.id.in_(id_list)).all()
29+
return db.session.scalars(select(cls).filter(cls.id.in_(id_list))).all()
3030

3131
def __repr__(self):
3232
return f"<Community {self.name}>"

flowapp/models/rules/whitelist.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from flowapp import utils
22
from ..base import db
33
from datetime import datetime
4+
from sqlalchemy import func, select
45
from flowapp.constants import RuleTypes, RuleOrigin
56

67

@@ -119,7 +120,7 @@ def get_by_whitelist_id(cls, whitelist_id: int):
119120
Returns:
120121
list: All RuleWhitelistCache objects with the specified whitelist_id
121122
"""
122-
return cls.query.filter_by(whitelist_id=whitelist_id).all()
123+
return db.session.scalars(select(cls).filter_by(whitelist_id=whitelist_id)).all()
123124

124125
@classmethod
125126
def clean_by_whitelist_id(cls, whitelist_id: int):
@@ -132,7 +133,9 @@ def clean_by_whitelist_id(cls, whitelist_id: int):
132133
Returns:
133134
int: Number of rows deleted
134135
"""
135-
deleted = cls.query.filter_by(whitelist_id=whitelist_id).delete()
136+
deleted = db.session.execute(
137+
db.delete(cls).filter_by(whitelist_id=whitelist_id)
138+
).rowcount
136139
db.session.commit()
137140
return deleted
138141

@@ -147,7 +150,9 @@ def delete_by_rule_id(cls, rule_id: int):
147150
Returns:
148151
int: Number of rows deleted
149152
"""
150-
deleted = cls.query.filter_by(rid=rule_id).delete()
153+
deleted = db.session.execute(
154+
db.delete(cls).filter_by(rid=rule_id)
155+
).rowcount
151156
db.session.commit()
152157
return deleted
153158

@@ -163,7 +168,9 @@ def count_by_rule(cls, rule_id: int, rule_type: RuleTypes):
163168
Returns:
164169
int: Number of cache entries
165170
"""
166-
return cls.query.filter_by(rid=rule_id, rtype=rule_type.value).count()
171+
return db.session.scalar(
172+
select(func.count()).select_from(cls).filter_by(rid=rule_id, rtype=rule_type.value)
173+
)
167174

168175
def __repr__(self):
169176
return f"<RuleWhitelistCache {self.rid} {self.rtype} {self.rorigin}>"

flowapp/models/user.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlalchemy import event
1+
from sqlalchemy import event, select
22
from .base import db, user_role, user_organization
33
from .organization import Organization
44

@@ -46,12 +46,12 @@ def update(self, form):
4646
self.organization.remove(org)
4747

4848
for role_id in form.role_ids.data:
49-
my_role = db.session.query(Role).filter_by(id=role_id).first()
49+
my_role = db.session.execute(select(Role).filter_by(id=role_id)).scalar_one()
5050
if my_role not in self.role:
5151
self.role.append(my_role)
5252

5353
for org_id in form.org_ids.data:
54-
my_org = db.session.query(Organization).filter_by(id=org_id).first()
54+
my_org = db.session.execute(select(Organization).filter_by(id=org_id)).scalar_one()
5555
if my_org not in self.organization:
5656
self.organization.append(my_org)
5757

0 commit comments

Comments
 (0)