Skip to content
2 changes: 1 addition & 1 deletion packages/google-cloud-ndb/google/cloud/ndb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from google.cloud.ndb import version

__version__ = version.__version__
__version__: str = version.__version__

from google.cloud.ndb.client import Client
from google.cloud.ndb.context import AutoBatcher
Expand Down
23 changes: 11 additions & 12 deletions packages/google-cloud-ndb/google/cloud/ndb/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,18 @@ def done_callback(self, cache_call):
"""
exception = cache_call.exception()
if exception:
for future in self.futures:
for future in self.futures: # type: ignore[attr-defined]
future.set_exception(exception)

else:
for future in self.futures:
for future in self.futures: # type: ignore[attr-defined]
future.set_result(None)

def make_call(self):
"""Make the actual call to the global cache. To be overridden."""
raise NotImplementedError

def future_info(self, key):
def future_info(self, key, value=None):
"""Generate info string for Future. To be overridden."""
raise NotImplementedError

Expand Down Expand Up @@ -279,7 +279,7 @@ def make_call(self):
"""Call :method:`GlobalCache.get`."""
return _global_cache().get(self.keys)

def future_info(self, key):
def future_info(self, key, value=None):
"""Generate info string for Future."""
return "GlobalCache.get({})".format(key)

Expand Down Expand Up @@ -373,7 +373,7 @@ def make_call(self):
"""Call :method:`GlobalCache.set`."""
return _global_cache().set(self.todo, expires=self.expires)

def future_info(self, key, value):
def future_info(self, key, value=None):
"""Generate info string for Future."""
return "GlobalCache.set({}, {})".format(key, value)

Expand Down Expand Up @@ -436,7 +436,7 @@ def make_call(self):
"""Call :method:`GlobalCache.set`."""
return _global_cache().set_if_not_exists(self.todo, expires=self.expires)

def future_info(self, key, value):
def future_info(self, key, value=None):
"""Generate info string for Future."""
return "GlobalCache.set_if_not_exists({}, {})".format(key, value)

Expand Down Expand Up @@ -482,7 +482,7 @@ def make_call(self):
"""Call :method:`GlobalCache.delete`."""
return _global_cache().delete(self.keys)

def future_info(self, key):
def future_info(self, key, value=None):
"""Generate info string for Future."""
return "GlobalCache.delete({})".format(key)

Expand Down Expand Up @@ -513,7 +513,7 @@ def make_call(self):
"""Call :method:`GlobalCache.watch`."""
return _global_cache().watch(self.todo)

def future_info(self, key, value):
def future_info(self, key, value=None):
"""Generate info string for Future."""
return "GlobalCache.watch({}, {})".format(key, value)

Expand Down Expand Up @@ -543,7 +543,7 @@ def make_call(self):
"""Call :method:`GlobalCache.unwatch`."""
return _global_cache().unwatch(self.keys)

def future_info(self, key):
def future_info(self, key, value=None):
"""Generate info string for Future."""
return "GlobalCache.unwatch({})".format(key)

Expand Down Expand Up @@ -580,7 +580,7 @@ def make_call(self):
"""Call :method:`GlobalCache.compare_and_swap`."""
return _global_cache().compare_and_swap(self.todo, expires=self.expires)

def future_info(self, key, value):
def future_info(self, key, value=None):
"""Generate info string for Future."""
return "GlobalCache.compare_and_swap({}, {})".format(key, value)

Expand Down Expand Up @@ -627,8 +627,7 @@ def global_lock_for_write(key):
tasklets.Future: Eventual result will be a lock value to be used later with
:func:`global_unlock`.
"""
lock = "." + str(uuid.uuid4())
lock = lock.encode("ascii")
lock = ("." + str(uuid.uuid4())).encode("ascii")
utils.logging_debug(log, "lock for write: {}", lock)

def new_value(old_value):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def lookup(key, options):
if use_global_cache and not key_locked:
if entity_pb is not _NOT_FOUND:
expires = context._global_cache_timeout(key, options)
serialized = entity_pb._pb.SerializeToString()
serialized = entity_pb._pb.SerializeToString() # type: ignore[attr-defined]
yield _cache.global_compare_and_swap(
cache_key, serialized, expires=expires
)
Expand Down
14 changes: 11 additions & 3 deletions packages/google-cloud-ndb/google/cloud/ndb/_datastore_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ def has_next_async(self):
if self._batch is None:
yield self._next_batch() # First time

assert self._batch is not None
assert self._index is not None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to have `assert` as part of this code path? Usually this is only used in tests.

Copy link
Copy Markdown
Contributor Author

@chalmerlowe chalmerlowe Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked for ways to provide type narrowing for mypy and assert is one way it can be done.

It is admittedly controversial for some in production code.

I went ahead and switched the asserts into if clauses throughout.

if self._index < len(self._batch):
raise tasklets.Return(True)

Expand All @@ -359,7 +361,9 @@ def probably_has_next(self):
return (
self._batch is None # Haven't even started yet
or self._has_next_batch # There's another batch to fetch
or self._index < len(self._batch) # Not done with current batch
or (
self._index is not None and self._index < len(self._batch)
) # Not done with current batch
)

@tasklets.tasklet
Expand Down Expand Up @@ -421,6 +425,8 @@ def next(self):
self._cursor_before = None
raise StopIteration

assert self._batch is not None
assert self._index is not None
# Won't block
next_result = self._batch[self._index]
self._index += 1
Expand All @@ -446,7 +452,7 @@ def _peek(self):
batch = self._batch
index = self._index

if batch and index < len(batch):
if batch and index is not None and index < len(batch):
return batch[index]

raise KeyError(index)
Expand Down Expand Up @@ -554,6 +560,7 @@ def next(self):
if not self.has_next():
raise StopIteration()

assert self._next_result is not None
# Won't block
next_result = self._next_result
self._next_result = None
Expand Down Expand Up @@ -718,6 +725,7 @@ def next(self):
if not self.has_next():
raise StopIteration()

assert self._next_result is not None
# Won't block
next_result = self._next_result
self._next_result = None
Expand Down Expand Up @@ -949,7 +957,7 @@ def _query_to_protobuf(query):
filter_pb = ancestor_filter_pb

elif isinstance(filter_pb, query_pb2.CompositeFilter):
filter_pb.filters._pb.add(property_filter=ancestor_filter_pb._pb)
filter_pb.filters._pb.add(property_filter=ancestor_filter_pb._pb) # type: ignore[attr-defined]

else:
filter_pb = query_pb2.CompositeFilter(
Expand Down
11 changes: 7 additions & 4 deletions packages/google-cloud-ndb/google/cloud/ndb/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,15 @@ class EventLoop(object):
"""

def __init__(self):
self.current = collections.deque()
self.idlers = collections.deque()
self._init()

def _init(self):
self.current: collections.deque = collections.deque()
self.idlers: collections.deque = collections.deque()
self.inactive = 0
self.queue = []
self.rpcs = {}
self.rpc_results = queue.Queue()
self.rpc_results: queue.Queue = queue.Queue()

def clear(self):
"""Remove all pending events without running any."""
Expand All @@ -139,7 +142,7 @@ def clear(self):
utils.logging_debug(log, " queue = {}", queue)
if rpcs:
utils.logging_debug(log, " rpcs = {}", rpcs)
self.__init__()
self._init()
current.clear()
idlers.clear()
queue[:] = []
Expand Down
22 changes: 15 additions & 7 deletions packages/google-cloud-ndb/google/cloud/ndb/_gql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import re
import time
from typing import Any

from google.cloud.ndb import context as context_module
from google.cloud.ndb import exceptions
Expand Down Expand Up @@ -485,7 +486,7 @@ def _Literal(self):
a string, integer, floating point value, boolean or None).
"""

literal = None
literal: Any = None

if self._next_symbol < len(self._symbols):
try:
Expand Down Expand Up @@ -770,27 +771,34 @@ def _raise_cast_error(message):


def _time_function(values):
t_tuple: tuple[int, ...]
if len(values) == 1:
value = values[0]
if isinstance(value, str):
try:
time_tuple = time.strptime(value, "%H:%M:%S")
st = time.strptime(value, "%H:%M:%S")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace st with a more meaningful name

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gemini suggested this refactor to clean up this code and reduce the number of if statements

import datetime
import time

def _time_function(values):
    # 1. Normalize input into a tuple of integers
    if len(values) == 1:
        val = values[0]
        if isinstance(val, str):
            try:
                st = time.strptime(val, "%H:%M:%S")
                t_tuple = (st.tm_hour, st.tm_min, st.tm_sec)
            except ValueError as e:
                _raise_cast_error(f"Format error: {e}, {values}")
        elif isinstance(val, int):
            t_tuple = (val,)
        else:
            _raise_cast_error(f"Invalid type {type(val)} for time()")
    elif 1 < len(values) < 4:
        t_tuple = tuple(values)
    else:
        _raise_cast_error(f"Invalid number of arguments: {len(values)}")

    # 2. Convert tuple to datetime.time using unpacking
    try:
        return datetime.time(*t_tuple)
    except (ValueError, TypeError) as e:
        _raise_cast_error(f"Time conversion error: {e}, {values}")

Copy link
Copy Markdown
Contributor Author

@chalmerlowe chalmerlowe Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@parthea

These two code blocks are fundamentally the same with the exception of item # 2. Convert tuple to datetime.time using unpacking so my comments focus on that.

In your comment, you assert that Gemini recommends this refactor to reduce the number of if statements.

Funnily enough:

Those if statements were also recommended by Gemini to enable greater type safety with mypy, because the datetime.time(*t_tuple) expression caused mypy errors.

google/cloud/ndb/_gql.py:793: error: Too many arguments for "time"  [call-arg]
google/cloud/ndb/_gql.py:793: error: Too many positional arguments for "time"  [misc]

Here is Gemini's rationale for the approach taken in this PR (this was the third change we made and Gemini felt it was the most important):

< BEGINNING of GEMINI's exposition >

  1. Exploding the "Splat" Operator (*t_tuple)

This is the most important change for Mypy:

Old: return datetime.time(*time_tuple)
The Problem: The * (splat) operator tells Python to unpack the tuple into arguments. However, datetime.time() accepts specific arguments: hour, minute, second, microsecond.
If time_tuple has a length of 1, it’s just the hour. If it's length 3, it's H, M, S.

Mypy cannot guarantee at compile-time that the length of the tuple matches what the function expects. It sees *time_tuple and worries you might be passing 10 arguments to a function that takes 4.

New:

if len(t_tuple) == 1:
    return datetime.time(t_tuple[0])
elif len(t_tuple) == 2:
    return datetime.time(t_tuple[0], t_tuple[1])
# ...

The Benefit: This is called Exhaustive Checking or Manual Unrolling.
By checking the len() explicitly, you are "narrowing" the type. Inside the if len(t_tuple) == 1 block, Mypy knows for a fact that t_tuple[0] exists and is the only argument.
This removes all ambiguity. Mypy can now prove the code is safe without needing a # type: ignore.

< END of GEMINI's exposition >

I am happy to resolve this in whatever way feels best to you, but I suspect that if I revert back to using the splat operator, we will have to go through another cycle of trying to find a way to keep mypy happy OR add one or more ignore pragmas.

except ValueError as error:
_raise_cast_error(
"Error during time conversion, {}, {}".format(error, values)
)
time_tuple = time_tuple[3:]
time_tuple = time_tuple[0:3]
t_tuple = (st.tm_hour, st.tm_min, st.tm_sec)
elif isinstance(value, int):
time_tuple = (value,)
t_tuple = (value,)
else:
_raise_cast_error("Invalid argument for time(), {}".format(value))
elif len(values) < 4:
time_tuple = tuple(values)
t_tuple = tuple(values)
else:
_raise_cast_error("Too many arguments for time(), {}".format(values))
try:
return datetime.time(*time_tuple)
if len(t_tuple) == 1:
return datetime.time(t_tuple[0])
elif len(t_tuple) == 2:
return datetime.time(t_tuple[0], t_tuple[1])
elif len(t_tuple) == 3:
return datetime.time(t_tuple[0], t_tuple[1], t_tuple[2])
else:
_raise_cast_error("Invalid arguments for time()")
except ValueError as error:
_raise_cast_error("Error during time conversion, {}, {}".format(error, values))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,10 @@ class Property(ProtocolBuffer.ProtocolMessage):
24: "EMPTY_LIST",
}

@classmethod
def Meaning_Name(cls, x):
return cls._Meaning_NAMES.get(x, "")

Meaning_Name = classmethod(Meaning_Name)

has_meaning_ = 0
meaning_ = 0
has_meaning_uri_ = 0
Expand Down Expand Up @@ -526,7 +525,7 @@ class Path_Element(ProtocolBuffer.ProtocolMessage):
def type(self):
# Force legacy byte-str to be a str.
if type(self.type_) is bytes:
return self.type_.decode()
return self.type_.decode() # type: ignore[attr-defined]
return self.type_

def set_type(self, x):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def MergePartialFromString(self, s):
d = Decoder(a, 0, len(a))
self.TryMerge(d)

def TryMerge(self, d):
raise NotImplementedError


class Decoder:
NUMERIC = 0
Expand Down
5 changes: 5 additions & 0 deletions packages/google-cloud-ndb/google/cloud/ndb/_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
import logging

from google.cloud.ndb import exceptions
from typing import Any

log = logging.getLogger(__name__)


class Options(object):
max_memcache_items: Any
force_writes: Any
propagation: Any

__slots__ = (
# Supported
"retries",
Expand Down
2 changes: 1 addition & 1 deletion packages/google-cloud-ndb/google/cloud/ndb/_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, future, info):
self.future = future
self.info = info
self.start_time = time.time()
self.elapsed_time = 0
self.elapsed_time = 0.0

def record_time(future):
self.elapsed_time = time.time() - self.start_time
Expand Down
6 changes: 3 additions & 3 deletions packages/google-cloud-ndb/google/cloud/ndb/_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def wraps_safely(obj, attr_names=functools.WRAPPER_ASSIGNMENTS):
are not copied to the wrappers and thus cause attribute errors. This
wrapper prevents that problem."""
return functools.wraps(
obj, assigned=(name for name in attr_names if hasattr(obj, name))
obj, assigned=tuple(name for name in attr_names if hasattr(obj, name))
)


Expand Down Expand Up @@ -84,7 +84,7 @@ def retry_wrapper(*args, **kwargs):
error = e
except BaseException as e:
# `e` is removed from locals at end of block
error = e # See: https://goo.gl/5J8BMK
error = e # type: ignore[assignment] # See: https://goo.gl/5J8BMK

if not is_transient_error(error):
# If we are in an inner retry block, use special nested
Expand All @@ -107,7 +107,7 @@ def retry_wrapper(*args, **kwargs):

# Unknown errors really want to show up as None, so manually set the error.
if isinstance(error, core_exceptions.Unknown):
error = "google.api_core.exceptions.Unknown"
error = "google.api_core.exceptions.Unknown" # type: ignore[assignment]

raise core_exceptions.RetryError(
"Maximum number of {} retries exceeded while calling {}".format(
Expand Down
4 changes: 2 additions & 2 deletions packages/google-cloud-ndb/google/cloud/ndb/_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ def _transaction_async(context, callback, read_only=False):
transaction_id = yield _datastore_api.begin_transaction(read_only, retries=0)
utils.logging_debug(log, "Transaction Id: {}", transaction_id)

on_commit_callbacks = []
transaction_complete_callbacks = []
on_commit_callbacks: list = []
transaction_complete_callbacks: list = []
tx_context = context.new(
transaction=transaction_id,
on_commit_callbacks=on_commit_callbacks,
Expand Down
17 changes: 12 additions & 5 deletions packages/google-cloud-ndb/google/cloud/ndb/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import threading
import uuid
from typing import Any, cast

from google.cloud.ndb import _eventloop
from google.cloud.ndb import exceptions
Expand Down Expand Up @@ -254,6 +255,11 @@ class _Context(_ContextTuple):
client (client.Client): The NDB client for this context.
"""

cache_policy: Any
global_cache_policy: Any
global_cache_timeout_policy: Any
datastore_policy: Any

def __new__(
cls,
client,
Expand Down Expand Up @@ -313,11 +319,12 @@ def __new__(
legacy_data=legacy_data,
)

context.set_cache_policy(cache_policy)
context.set_global_cache_policy(global_cache_policy)
context.set_global_cache_timeout_policy(global_cache_timeout_policy)
context.set_datastore_policy(datastore_policy)
context.set_retry_state(retry)
ctx = cast(Any, context)
Comment thread
chalmerlowe marked this conversation as resolved.
ctx.set_cache_policy(cache_policy)
ctx.set_global_cache_policy(global_cache_policy)
ctx.set_global_cache_timeout_policy(global_cache_timeout_policy)
ctx.set_datastore_policy(datastore_policy)
ctx.set_retry_state(retry)

return context

Expand Down
Loading
Loading