Skip to content

Commit 091dc7c

Browse files
committed
Disabling SQL compilation cache for Solr releases earlier than 9.0
1 parent 20fb56e commit 091dc7c

11 files changed

Lines changed: 329 additions & 82 deletions

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ To connect to Solr with SQLAlchemy, the following URL pattern can be used:
2828
solr://<username>:<password>@<host>:<port>/solr/<collection>[?parameter=value]
2929
```
3030

31+
_Note_: port 8983 is used when `port` in the URL is omitted
32+
3133
### Authentication
3234

3335
#### Basic Authentication
@@ -162,6 +164,7 @@ translates to `[2024-01-01T00:00:00Z TO *]`
162164
| Aliases |||||||
163165
| Built-in date range compilation |||||||
164166
| `SELECT` _expression_ statements |||||||
167+
| SQL compilation caching |||||||
165168

166169
## Use Cases
167170

src/sqlalchemy_solr/admin/solr_spec.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,51 @@
11
from requests import Session
2+
from sqlalchemy_solr import defaults
23

34

45
class SolrSpec:
56

67
_spec = None
78

8-
def __init__(self, solr_base_url):
9+
def __init__(self, url):
10+
"""
11+
Initializes a SolrSpec object
12+
13+
:param url: Solr base url which can be a string HTTP(S) URL or a sqlalchemy.engine.url.URL.
14+
"""
15+
916
session = Session()
17+
18+
if isinstance(url, str):
19+
base_url = url
20+
else:
21+
if "verify_ssl" in url.query and url.query["verify_ssl"] in [
22+
"False",
23+
"false",
24+
]:
25+
session.verify = False
26+
27+
token = None
28+
if "token" in url.query:
29+
token = url.query["token"]
30+
31+
if token is not None:
32+
session.headers.update({"Authorization": f"Bearer {token}"})
33+
else:
34+
session.auth = (url.username, url.password)
35+
36+
proto = "http"
37+
if "use_ssl" in url.query and url.query["use_ssl"] in ["True", "true"]:
38+
proto = "https"
39+
40+
server_path = url.database.split("/")[0]
41+
42+
port = url.port or defaults.PORT
43+
base_url = f"{proto}://{url.host}:{port}/{server_path}"
44+
1045
sys_info_response = session.get(
11-
solr_base_url + "/admin/info/system", params={"wt": "json"}
46+
base_url + "/admin/info/system", params={"wt": "json"}
1247
)
48+
1349
spec_version = sys_info_response.json()["lucene"]["solr-spec-version"]
1450
self._spec = list(map(int, spec_version.split(".")))
1551

src/sqlalchemy_solr/base.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
from sqlalchemy.sql import expression
3131
from sqlalchemy.sql import operators
3232
from sqlalchemy.sql.expression import BindParameter
33+
from sqlalchemy_solr import release_flags
3334

3435
from . import solrdbapi as module
3536
from .solr_type_compiler import SolrTypeCompiler
36-
from .solrdbapi import Connection
3737
from .type_map import metadata_type_map
3838

3939
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.ERROR)
@@ -42,8 +42,6 @@
4242
class SolrCompiler(compiler.SQLCompiler):
4343
# pylint: disable=abstract-method
4444

45-
SOLR_DATE_RANGE_TRANS_RELEASE = 9
46-
4745
merge_ops = (operators.ge, operators.gt, operators.le, operators.lt)
4846
bounds = {
4947
operators.ge: "[",
@@ -70,7 +68,10 @@ def visit_binary(
7068
):
7169

7270
# Handled in Solr 9
73-
if Connection.solr_spec.spec()[0] >= self.SOLR_DATE_RANGE_TRANS_RELEASE:
71+
if (
72+
SolrDialect.solr_spec.spec()[0]
73+
>= release_flags.SOLR_DATE_RANGE_TRANS_RELEASE
74+
):
7475
return super().visit_binary(binary, override_operator, eager_grouping, **kw)
7576

7677
if binary.operator not in self.merge_ops:
@@ -157,7 +158,10 @@ def visit_binary(
157158

158159
def visit_clauselist(self, clauselist, **kw):
159160
# Handled in Solr 9
160-
if Connection.solr_spec.spec()[0] >= self.SOLR_DATE_RANGE_TRANS_RELEASE:
161+
if (
162+
SolrDialect.solr_spec.spec()[0]
163+
>= release_flags.SOLR_DATE_RANGE_TRANS_RELEASE
164+
):
161165
return super().visit_clauselist(clauselist, **kw)
162166

163167
if clauselist.operator == operators.and_:
@@ -537,6 +541,8 @@ class SolrDialect(default.DefaultDialect):
537541
supports_native_boolean = True
538542
supports_statement_cache = True
539543

544+
solr_spec = None
545+
540546
def __init__(self, **kw):
541547
default.DefaultDialect.__init__(self, **kw)
542548
self.supported_extensions = []

src/sqlalchemy_solr/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
PORT = 8983

src/sqlalchemy_solr/http.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323

2424
from requests import RequestException
2525
from requests import Session
26+
from sqlalchemy_solr import defaults
27+
from sqlalchemy_solr import release_flags
28+
from sqlalchemy_solr.admin.solr_spec import SolrSpec
2629
from sqlalchemy_solr.solrdbapi.api_exceptions import DatabaseError
2730

2831
from .api_globals import _HEADER
@@ -86,7 +89,7 @@ def create_connect_args(self, url):
8689

8790
# Save this for later use.
8891
self.host = url.host
89-
self.port = url_port
92+
self.port = url.port or defaults.PORT
9093
self.username = url.username
9194
self.password = url.password
9295
self.db = db
@@ -96,11 +99,7 @@ def create_connect_args(self, url):
9699
# Prepare a session with proper authorization handling.
97100
session = Session()
98101
# session.verify property which is bydefault true so Handled here
99-
if "verify_ssl" in url.query and url.query["verify_ssl"] in [
100-
False,
101-
"False",
102-
"false",
103-
]:
102+
if "verify_ssl" in url.query and url.query["verify_ssl"] in ["False", "false"]:
104103
session.verify = False
105104

106105
if self.token is not None:
@@ -209,3 +208,22 @@ def get_unique_columns(self, columns):
209208
columns_set.remove(c["name"])
210209

211210
return unique_columns
211+
212+
def on_connect_url(self, url):
213+
SolrDialect.solr_spec = SolrSpec(url)
214+
215+
def do_on_connect(connection): # pylint: disable=unused-argument
216+
SolrDialect.solr_spec = SolrSpec(url)
217+
218+
if (
219+
SolrDialect.solr_spec.spec()[0]
220+
< release_flags.SOLR_DATE_RANGE_TRANS_RELEASE
221+
):
222+
logging.warning(
223+
"Solr version %s less than 9, SQL compilation cache disabled",
224+
SolrDialect.solr_spec.spec()[0],
225+
)
226+
SolrDialect_http.supports_statement_cache = False
227+
SolrDialect.supports_statement_cache = False
228+
229+
return do_on_connect
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SOLR_DATE_RANGE_TRANS_RELEASE = 9

src/sqlalchemy_solr/solrdbapi/_solrdbapi.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
22

33
from requests import Session
4+
from sqlalchemy_solr import defaults
45

56
from .. import type_map
6-
from ..admin.solr_spec import SolrSpec
77
from ..api_globals import _HEADER
88
from ..api_globals import _PAYLOAD
99
from ..message_formatter import MessageFormatter
@@ -286,7 +286,6 @@ def __iter__(self):
286286
class Connection:
287287
# pylint: disable=too-many-instance-attributes
288288

289-
solr_spec = None
290289
mf = MessageFormatter()
291290

292291
# pylint: disable=too-many-arguments
@@ -314,8 +313,6 @@ def __init__(
314313
self._session = session
315314
self._connected = True
316315

317-
Connection.solr_spec = SolrSpec(f"{proto}{host}:{port}/{server_path}")
318-
319316
SolrTableReflection.connection = self
320317

321318
@property
@@ -374,23 +371,23 @@ def cursor(self):
374371
# pylint: disable=too-many-arguments
375372
def connect(
376373
host,
377-
port=8047,
378-
db=None,
374+
db,
375+
server_path,
376+
collection,
377+
port=defaults.PORT,
379378
username=None,
380379
password=None,
381-
server_path="solr",
382-
collection=None,
383-
use_ssl=False,
380+
use_ssl=None,
384381
verify_ssl=None,
385382
token=None,
386383
):
387384

388385
session = Session()
389386
# bydefault session.verify is set to True
390-
if verify_ssl is not None and verify_ssl in [False, "False", "false"]:
387+
if verify_ssl is not None and verify_ssl in ["False", "false"]:
391388
session.verify = False
392389

393-
if use_ssl in [True, "True", "true"]:
390+
if use_ssl in ["True", "true"]:
394391
proto = "https://"
395392
else:
396393
proto = "http://"

tests/assertions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import pytest
2+
from sqlalchemy_solr.admin.solr_spec import SolrSpec
3+
4+
5+
def assert_solr_release(settings, releases):
6+
solr_spec = SolrSpec(settings["SOLR_BASE_URL"])
7+
if solr_spec.spec()[0] not in releases:
8+
pytest.skip(
9+
reason=f"Solr spec version {solr_spec} not compatible with the current test"
10+
)

tests/test_sql_compilation_caching.py

Lines changed: 0 additions & 58 deletions
This file was deleted.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from sqlalchemy import select
2+
from sqlalchemy.sql.expression import bindparam
3+
from sqlalchemy.util.langhelpers import _symbol
4+
from tests import assertions
5+
from tests.setup import prepare_orm
6+
7+
releases = [6, 7, 8]
8+
9+
10+
class TestSQLCompilationCaching:
11+
12+
def test_sql_compilation_caching_1(self, settings):
13+
assertions.assert_solr_release(settings, releases)
14+
15+
engine, t = prepare_orm(settings)
16+
17+
qry_1 = (select(t.c.COUNTRY_s).select_from(t)).limit(1)
18+
qry_2 = (select(t.c.COUNTRY_s).select_from(t)).limit(10)
19+
20+
with engine.connect() as connection:
21+
result_1 = connection.execute(qry_1)
22+
result_2 = connection.execute(qry_2)
23+
24+
assert result_1.context.cache_hit == _symbol("NO_DIALECT_SUPPORT")
25+
assert result_2.context.cache_hit == _symbol("NO_DIALECT_SUPPORT")
26+
27+
def test_sql_compilation_caching_2(self, settings):
28+
assertions.assert_solr_release(settings, releases)
29+
30+
engine, t = prepare_orm(settings)
31+
32+
qry_1 = (select(t.c.COUNTRY_s).select_from(t)).limit(1).offset(1)
33+
qry_2 = (select(t.c.COUNTRY_s).select_from(t)).limit(1).offset(2)
34+
35+
with engine.connect() as connection:
36+
result_1 = connection.execute(qry_1)
37+
result_2 = connection.execute(qry_2)
38+
39+
assert result_1.context.cache_hit == _symbol("NO_DIALECT_SUPPORT")
40+
assert result_2.context.cache_hit == _symbol("NO_DIALECT_SUPPORT")
41+
42+
def test_sql_compilation_caching_3(self, settings):
43+
assertions.assert_solr_release(settings, releases)
44+
45+
engine, t = prepare_orm(settings)
46+
47+
qry = select(t).where(t.c.COUNTRY_s == bindparam("COUNTRY_s")).limit(10)
48+
49+
with engine.connect() as connection:
50+
result_1 = connection.execute(qry, {"COUNTRY_s": "Sweden"})
51+
result_2 = connection.execute(qry, {"COUNTRY_s": "France"})
52+
53+
assert result_1.context.cache_hit == _symbol("NO_DIALECT_SUPPORT")
54+
assert result_2.context.cache_hit == _symbol("NO_DIALECT_SUPPORT")

0 commit comments

Comments
 (0)