Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 1 addition & 37 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,7 @@ def __init__(self,
HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
}

self.executor = self._create_thread_pool_executor(max_workers=executor_threads)
self.executor = ThreadPoolExecutor(max_workers=executor_threads)
self.scheduler = _Scheduler(self.executor)

self._lock = RLock()
Expand All @@ -1411,42 +1411,6 @@ def __init__(self,
if application_version is not None:
self.application_version = application_version

def _create_thread_pool_executor(self, **kwargs):
"""
Create a ThreadPoolExecutor for the cluster. In most cases, the built-in
`concurrent.futures.ThreadPoolExecutor` is used.

Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor`
to hang indefinitely. In that case, the user needs to have the `futurist`
package so we can use the `futurist.GreenThreadPoolExecutor` class instead.

:param kwargs: All keyword args are passed to the ThreadPoolExecutor constructor.
:return: A ThreadPoolExecutor instance.
"""
tpe_class = ThreadPoolExecutor
if sys.version_info[0] >= 3 and sys.version_info[1] >= 7:
try:
from cassandra.io.eventletreactor import EventletConnection
is_eventlet = issubclass(self.connection_class, EventletConnection)
except:
# Eventlet is not available or can't be detected
return tpe_class(**kwargs)

if is_eventlet:
try:
from futurist import GreenThreadPoolExecutor
tpe_class = GreenThreadPoolExecutor
except ImportError:
# futurist is not available
raise ImportError(
("Python 3.7+ and Eventlet cause the `concurrent.futures.ThreadPoolExecutor` "
"to hang indefinitely. If you want to use the Eventlet reactor, you "
"need to install the `futurist` package to allow the driver to use "
"the GreenThreadPoolExecutor. See https://github.com/eventlet/eventlet/issues/508 "
"for more details."))

return tpe_class(**kwargs)

def register_user_type(self, keyspace, user_type, klass):
"""
Registers a class to use to represent a particular user-defined type.
Expand Down
1 change: 0 additions & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ gevent
eventlet
cython>=3.0
packaging
futurist
asynctest
7 changes: 0 additions & 7 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,6 @@ def is_monkey_patched():

from cassandra.io.eventletreactor import EventletConnection
connection_class = EventletConnection

try:
from futurist import GreenThreadPoolExecutor
thread_pool_executor_class = GreenThreadPoolExecutor
except:
# futurist is installed only with python >=3.7
pass
elif "asyncore" in EVENT_LOOP_MANAGER:
from cassandra.io.asyncorereactor import AsyncoreConnection
connection_class = AsyncoreConnection
Expand Down
5 changes: 1 addition & 4 deletions tests/integration/advanced/graph/fluent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,7 @@ def _validate_prop(key, value, unittest):
typ = int

elif any(key.startswith(t) for t in ('long',)):
if sys.version_info >= (3, 0):
typ = int
else:
typ = long
typ = int
elif any(key.startswith(t) for t in ('float', 'double')):
typ = float
elif any(key.startswith(t) for t in ('polygon',)):
Expand Down
9 changes: 4 additions & 5 deletions tests/integration/cqlengine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
# limitations under the License.
import unittest

import sys

from cassandra.cqlengine.connection import get_session
from cassandra.cqlengine.models import Model
from cassandra.cqlengine import columns

from uuid import uuid4


class TestQueryUpdateModel(Model):

partition = columns.UUID(primary_key=True, default=uuid4)
Expand All @@ -33,6 +32,7 @@ class TestQueryUpdateModel(Model):
text_list = columns.List(columns.Text, required=False)
text_map = columns.Map(columns.Text, columns.Text, required=False)


class BaseCassEngTestCase(unittest.TestCase):

session = None
Expand All @@ -48,6 +48,5 @@ def assertNotHasAttr(self, obj, attr):
self.assertFalse(hasattr(obj, attr),
"{0} shouldn't have the attribute: {1}".format(obj, attr))

if sys.version_info > (3, 0):
def assertItemsEqual(self, first, second, msg=None):
return self.assertCountEqual(first, second, msg)
def assertItemsEqual(self, first, second, msg=None):
return self.assertCountEqual(first, second, msg)
22 changes: 9 additions & 13 deletions tests/integration/cqlengine/columns/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_datetime_tzinfo_io(self):
class TZ(tzinfo):
def utcoffset(self, date_time):
return timedelta(hours=-1)

def dst(self, date_time):
return None

Expand Down Expand Up @@ -91,7 +92,7 @@ def test_datetime_none(self):
self.assertIsNone(dts[0][0])

def test_datetime_invalid(self):
dt_value= 'INVALID'
dt_value = 'INVALID'
with self.assertRaises(TypeError):
self.DatetimeTest.objects.create(test_id=4, created_at=dt_value)

Expand Down Expand Up @@ -125,7 +126,7 @@ def test_datetime_truncate_microseconds(self):
dt_truncated = datetime(2024, 12, 31, 10, 10, 10, 923000)
self.DatetimeTest.objects.create(test_id=6, created_at=dt_value)
dt2 = self.DatetimeTest.objects(test_id=6).first()
self.assertEqual(dt2.created_at,dt_truncated)
self.assertEqual(dt2.created_at, dt_truncated)
finally:
# We need to always return behavior to default
DateTime.truncate_microseconds = False
Expand Down Expand Up @@ -191,7 +192,7 @@ def test_varint_io(self):
self.VarIntTest.objects.create(test_id=0, bignum="not_a_number")


class DataType():
class DataType:
@classmethod
def setUpClass(cls):
if PROTOCOL_VERSION < 4 or CASSANDRA_VERSION < Version("3.0"):
Expand Down Expand Up @@ -344,6 +345,7 @@ def setUpClass(cls):
)
super(TestBoolean, cls).setUpClass()


@greaterthanorequalcass3_11
class TestDuration(DataType, BaseCassEngTestCase):
@classmethod
Expand Down Expand Up @@ -507,7 +509,7 @@ def test_timeuuid_io(self):
class TestInteger(BaseCassEngTestCase):
class IntegerTest(Model):

test_id = UUID(primary_key=True, default=lambda:uuid4())
test_id = UUID(primary_key=True, default=lambda: uuid4())
value = Integer(default=0, required=True)

def test_default_zero_fields_validate(self):
Expand All @@ -519,8 +521,8 @@ def test_default_zero_fields_validate(self):
class TestBigInt(BaseCassEngTestCase):
class BigIntTest(Model):

test_id = UUID(primary_key=True, default=lambda:uuid4())
value = BigInt(default=0, required=True)
test_id = UUID(primary_key=True, default=lambda: uuid4())
value = BigInt(default=0, required=True)

def test_default_zero_fields_validate(self):
""" Tests that bigint columns with a default value of 0 validate """
Expand Down Expand Up @@ -612,10 +614,6 @@ def test_type_checking(self):
with self.assertRaises(ValidationError):
Ascii().validate('Beyonc' + chr(233))

if sys.version_info < (3, 1):
with self.assertRaises(ValidationError):
Ascii().validate(u'Beyonc' + unichr(233))

def test_unaltering_validation(self):
""" Test the validation step doesn't re-interpret values. """
self.assertEqual(Ascii().validate(''), '')
Expand Down Expand Up @@ -736,8 +734,6 @@ def test_type_checking(self):

Text().validate("!#$%&\'()*+,-./")
Text().validate('Beyonc' + chr(233))
if sys.version_info < (3, 1):
Text().validate(u'Beyonc' + unichr(233))

def test_unaltering_validation(self):
""" Test the validation step doesn't re-interpret values. """
Expand Down Expand Up @@ -810,7 +806,7 @@ def test_conversion_specific_date(self):
from uuid import UUID
assert isinstance(uuid, UUID)

ts = (uuid.time - 0x01b21dd213814000) / 1e7 # back to a timestamp
ts = (uuid.time - 0x01b21dd213814000) / 1e7 # back to a timestamp
new_dt = datetime.fromtimestamp(ts, tz=timezone.utc).replace(tzinfo=None)

# checks that we created a UUID1 with the proper timestamp
Expand Down
14 changes: 3 additions & 11 deletions tests/integration/standard/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,14 +1087,6 @@ def test_export_keyspace_schema_udts(self):
Test udt exports
"""

if PROTOCOL_VERSION < 3:
raise unittest.SkipTest(
"Protocol 3.0+ is required for UDT change events, currently testing against %r"
% (PROTOCOL_VERSION,))

if sys.version_info[0:2] != (2, 7):
raise unittest.SkipTest('This test compares static strings generated from dict items, which may change orders. Test with 2.7.')

cluster = TestCluster()
session = cluster.connect()

Expand Down Expand Up @@ -1591,7 +1583,7 @@ def test_function_no_parameters(self):

with self.VerifiedFunction(self, **kwargs) as vf:
fn_meta = self.keyspace_function_meta[vf.signature]
self.assertRegex(fn_meta.as_cql_query(), "CREATE FUNCTION.*%s\(\) .*" % kwargs['name'])
self.assertRegex(fn_meta.as_cql_query(), r"CREATE FUNCTION.*%s\(\) .*" % kwargs['name'])

def test_functions_follow_keyspace_alter(self):
"""
Expand Down Expand Up @@ -1639,12 +1631,12 @@ def test_function_cql_called_on_null(self):
kwargs['called_on_null_input'] = True
with self.VerifiedFunction(self, **kwargs) as vf:
fn_meta = self.keyspace_function_meta[vf.signature]
self.assertRegex(fn_meta.as_cql_query(), "CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*")
self.assertRegex(fn_meta.as_cql_query(), r"CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*")

kwargs['called_on_null_input'] = False
with self.VerifiedFunction(self, **kwargs) as vf:
fn_meta = self.keyspace_function_meta[vf.signature]
self.assertRegex(fn_meta.as_cql_query(), "CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*")
self.assertRegex(fn_meta.as_cql_query(), r"CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*")


class AggregateMetadata(FunctionTest):
Expand Down
22 changes: 10 additions & 12 deletions tests/unit/advanced/test_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import unittest

import logging
import sys
from unittest.mock import sentinel

from cassandra import ConsistencyLevel
Expand All @@ -32,7 +31,6 @@
from cassandra.datastax.graph.query import GraphOptions
from cassandra.datastax.insights.registry import insights_registry
from cassandra.datastax.insights.serializers import initialize_registry
from cassandra.datastax.insights.util import namespace
from cassandra.policies import (
RoundRobinPolicy,
LoadBalancingPolicy,
Expand Down Expand Up @@ -63,8 +61,7 @@ class NoConfAsDict(object):
obj = NoConfAsDict()

ns = 'tests.unit.advanced.test_insights'
if sys.version_info > (3,):
ns += '.TestGetConfig.test_invalid_object.<locals>'
ns += '.TestGetConfig.test_invalid_object.<locals>'

# no default
# ... as a policy
Expand Down Expand Up @@ -102,6 +99,7 @@ def superclass_sentinel_serializer(obj):
self.assertIs(insights_registry.serialize(SubclassSentinel(), default=object()),
sentinel.serialized_superclass)


class TestConfigAsDict(unittest.TestCase):

# graph/query.py
Expand Down Expand Up @@ -253,35 +251,35 @@ def test_constant_reconnection_policy(self):
self.assertEqual(
insights_registry.serialize(ConstantReconnectionPolicy(3, 200)),
{'type': 'ConstantReconnectionPolicy',
'namespace': 'cassandra.policies',
'options': {'delay': 3, 'max_attempts': 200}
'namespace': 'cassandra.policies',
'options': {'delay': 3, 'max_attempts': 200}
}
)

def test_exponential_reconnection_policy(self):
self.assertEqual(
insights_registry.serialize(ExponentialReconnectionPolicy(4, 100, 10)),
{'type': 'ExponentialReconnectionPolicy',
'namespace': 'cassandra.policies',
'options': {'base_delay': 4, 'max_delay': 100, 'max_attempts': 10}
'namespace': 'cassandra.policies',
'options': {'base_delay': 4, 'max_delay': 100, 'max_attempts': 10}
}
)

def test_retry_policy(self):
self.assertEqual(
insights_registry.serialize(RetryPolicy()),
{'type': 'RetryPolicy',
'namespace': 'cassandra.policies',
'options': {}
'namespace': 'cassandra.policies',
'options': {}
}
)

def test_spec_exec_policy(self):
self.assertEqual(
insights_registry.serialize(SpeculativeExecutionPolicy()),
{'type': 'SpeculativeExecutionPolicy',
'namespace': 'cassandra.policies',
'options': {}
'namespace': 'cassandra.policies',
'options': {}
}
)

Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def test_schedule_with_max(self):
if i == 0:
self._assert_between(delay, base_delay*0.85, base_delay*1.15)
elif i < 6:
value = base_delay * (2 ** i)
value = base_delay * (2 ** i)
self._assert_between(delay, value*85/100, value*1.15)
else:
self._assert_between(delay, max_delay*85/100, max_delay*1.15)
Expand Down Expand Up @@ -956,7 +956,7 @@ def test_schedule_with_jitter(self):
"""
for i in range(100):
base_delay = float(randint(2, 5))
max_delay = (base_delay - 1) * 100.0
max_delay = (base_delay - 1) * 100.0
ep = ExponentialReconnectionPolicy(base_delay, max_delay, max_attempts=64)
schedule = ep.new_schedule()
for i in range(64):
Expand Down Expand Up @@ -1467,13 +1467,12 @@ def get_replicas(keyspace, packed_key):
# We don't allow randomness for ordering the replicas in RoundRobin
hfp._child_policy._child_policy._position = 0


mocked_query = Mock()
query_plan = hfp.make_query_plan("keyspace", mocked_query)
# First the not filtered replica, and then the rest of the allowed hosts ordered
query_plan = list(query_plan)
self.assertEqual(query_plan[0], Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy))
self.assertEqual(set(query_plan[1:]),{Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy),
self.assertEqual(set(query_plan[1:]), {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy),
Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy)})

def test_create_whitelist(self):
Expand Down
Loading