|
1 | 1 | from sqlalchemy import select, func |
| 2 | +from datetime import datetime, date |
2 | 3 | import pytest |
3 | 4 |
|
4 | 5 | from models import Item, Product |
@@ -117,6 +118,91 @@ def test_json_extract_filter(self, SessionFactory, pattern, value, expected_ids) |
117 | 118 | results = session.execute(stmt).scalars().all() |
118 | 119 | assert {item.id for item in results} == expected_ids |
119 | 120 |
|
| 121 | + def test_default_values(self, SessionFactory): |
| 122 | + dt = datetime(2025, 1, 1, 2, 3, 4) |
| 123 | + |
| 124 | + with SessionFactory() as session: |
| 125 | + session.add_all([ |
| 126 | + Product(id=5, name="foo", category="A"), |
| 127 | + Product(name="bar", active=False, created_at=dt), |
| 128 | + ]) |
| 129 | + session.commit() |
| 130 | + |
| 131 | + products = session.execute(select(Product)).scalars().all() |
| 132 | + |
| 133 | + assert products[0].active |
| 134 | + assert products[0].created_at is not None |
| 135 | + assert products[0].category == "A" |
| 136 | + |
| 137 | + assert products[1].id == 6 |
| 138 | + assert not products[1].active |
| 139 | + assert products[1].created_at == dt |
| 140 | + assert products[1].category == "unknown" |
| 141 | + |
| 142 | + @pytest.mark.parametrize( |
| 143 | + "operator, value, expected_ids", |
| 144 | + [ |
| 145 | + ("is", True, {1, 3}), |
| 146 | + ("is_not", True, {2}), |
| 147 | + ("is", False, {2}), |
| 148 | + ("is_not", False, {1, 3}), |
| 149 | + ] |
| 150 | + ) |
| 151 | + def test_is_filter(self, SessionFactory, operator, value, expected_ids): |
| 152 | + with SessionFactory() as session: |
| 153 | + session.add_all([ |
| 154 | + Product(id=1, name="foo", active=True), |
| 155 | + Product(id=2, name="bar", active=False), |
| 156 | + Product(id=3, name="foobar", active=True), |
| 157 | + ]) |
| 158 | + session.commit() |
| 159 | + |
| 160 | + stmt = select(Product) |
| 161 | + if operator == "is": |
| 162 | + stmt = stmt.where(Product.active.is_(value)) |
| 163 | + else: |
| 164 | + stmt = stmt.where(Product.active.is_not(value)) |
| 165 | + |
| 166 | + results = session.execute(stmt).scalars().all() |
| 167 | + assert {item.id for item in results} == expected_ids |
| 168 | + |
| 169 | + @pytest.mark.parametrize( |
| 170 | + "operator, value, expected_ids", |
| 171 | + [ |
| 172 | + ("==", date(2025, 1, 1), {1}), |
| 173 | + ("!=", date(2025, 1, 2), {1, 3, 4}), |
| 174 | + (">", date(2025, 1, 2), {3, 4}), |
| 175 | + (">", date(2025, 1, 10), set()), |
| 176 | + ] |
| 177 | + ) |
| 178 | + def test_date_filter(self, SessionFactory, operator, value, expected_ids): |
| 179 | + """ |
| 180 | + stmt = select(NetLiquidationValue).where( |
| 181 | + func.DATE(NetLiquidationValue.created_at) == _date |
| 182 | + ) |
| 183 | +
|
| 184 | + """ |
| 185 | + |
| 186 | + with SessionFactory() as session: |
| 187 | + session.add_all([ |
| 188 | + Product(id=1, name="foo", created_at=datetime(2025, 1, 1, 1, 1, 1)), |
| 189 | + Product(id=2, name="bar", created_at=datetime(2025, 1, 2, 2, 2, 2)), |
| 190 | + Product(id=3, name="foobar", created_at=datetime(2025, 1, 3, 3, 3, 3)), |
| 191 | + Product(id=4, name="barfoo", created_at=datetime(2025, 1, 4, 4, 4, 4)), |
| 192 | + ]) |
| 193 | + session.commit() |
| 194 | + |
| 195 | + stmt = select(Product) |
| 196 | + if operator == "==": |
| 197 | + stmt = stmt.where(func.DATE(Product.created_at) == value) |
| 198 | + elif operator == "!=": |
| 199 | + stmt = stmt.where(func.DATE(Product.created_at) != value) |
| 200 | + elif operator == ">": |
| 201 | + stmt = stmt.where(func.DATE(Product.created_at) > value) |
| 202 | + |
| 203 | + results = session.execute(stmt).scalars().all() |
| 204 | + assert {item.id for item in results} == expected_ids |
| 205 | + |
120 | 206 | def test_session_inception(self, SessionFactory): |
121 | 207 | with SessionFactory() as session1: |
122 | 208 | session1.add(Item(id=1, name="foo")) |
|
0 commit comments