Skip to content

Commit 86202d4

Browse files
Merge pull request #1690 from codeflash-ai/fix/comparator-itertools-count
fix: handle itertools types in comparator with Python 3.9-3.14 support
2 parents 4843748 + 3a33fe4 commit 86202d4

2 files changed

Lines changed: 357 additions & 0 deletions

File tree

codeflash/verification/comparator.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import datetime
44
import decimal
55
import enum
6+
import itertools
67
import math
78
import re
89
import types
10+
import warnings
911
import weakref
1012
from collections import ChainMap, OrderedDict, deque
1113
from importlib.util import find_spec
@@ -528,6 +530,55 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
528530
)
529531
return comparator(orig_dict, new_dict, superset_obj)
530532

533+
# Handle itertools infinite iterators
534+
if isinstance(orig, itertools.count):
535+
# repr reliably reflects internal state, e.g. "count(5)" or "count(5, 2)"
536+
return repr(orig) == repr(new)
537+
538+
if isinstance(orig, itertools.repeat):
539+
# repr reliably reflects internal state, e.g. "repeat(5)" or "repeat(5, 3)"
540+
return repr(orig) == repr(new)
541+
542+
if isinstance(orig, itertools.cycle):
543+
# cycle has no useful repr and no public attributes; use __reduce__ to extract state.
544+
# __reduce__ returns (cls, (remaining_iter,), (saved_items, first_pass_done)).
545+
# NOTE: consuming the remaining_iter is destructive to the cycle object, but this is
546+
# acceptable since the comparator is the final consumer of captured return values.
547+
# NOTE: __reduce__ on itertools.cycle was removed in Python 3.14.
548+
try:
549+
with warnings.catch_warnings():
550+
warnings.simplefilter("ignore", DeprecationWarning)
551+
orig_reduce = orig.__reduce__()
552+
new_reduce = new.__reduce__()
553+
orig_remaining = list(orig_reduce[1][0])
554+
new_remaining = list(new_reduce[1][0])
555+
orig_saved, orig_started = orig_reduce[2]
556+
new_saved, new_started = new_reduce[2]
557+
if orig_started != new_started:
558+
return False
559+
return comparator(orig_remaining, new_remaining, superset_obj) and comparator(
560+
orig_saved, new_saved, superset_obj
561+
)
562+
except TypeError:
563+
# Python 3.14+: __reduce__ removed. Fall back to consuming elements from both
564+
# cycles and comparing. Since the comparator is the final consumer, this is safe.
565+
sample_size = 200
566+
orig_sample = [next(orig) for _ in range(sample_size)]
567+
new_sample = [next(new) for _ in range(sample_size)]
568+
return comparator(orig_sample, new_sample, superset_obj)
569+
570+
# Handle remaining itertools types (chain, islice, starmap, product, permutations, etc.)
571+
# by materializing into lists. count/repeat/cycle are already handled above.
572+
# NOTE: materializing is destructive (consumes the iterator) and will hang on infinite input,
573+
# but the three infinite itertools types are already handled above.
574+
if type(orig).__module__ == "itertools":
575+
if isinstance(orig, itertools.groupby):
576+
# groupby yields (key, group_iterator) — materialize groups too
577+
orig_groups = [(k, list(g)) for k, g in orig]
578+
new_groups = [(k, list(g)) for k, g in new]
579+
return comparator(orig_groups, new_groups, superset_obj)
580+
return comparator(list(orig), list(new), superset_obj)
581+
531582
# re.Pattern can be made better by DFA Minimization and then comparing
532583
if isinstance(
533584
orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)

tests/test_comparator.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,312 @@ class Color4(IntFlag):
417417
assert not comparator(id1, id3)
418418

419419

420+
def test_itertools_count() -> None:
421+
import itertools
422+
423+
# Equal: same start and step (default step=1)
424+
assert comparator(itertools.count(0), itertools.count(0))
425+
assert comparator(itertools.count(5), itertools.count(5))
426+
assert comparator(itertools.count(0, 1), itertools.count(0, 1))
427+
assert comparator(itertools.count(10, 3), itertools.count(10, 3))
428+
429+
# Equal: negative start and step
430+
assert comparator(itertools.count(-5, -2), itertools.count(-5, -2))
431+
432+
# Equal: float start and step
433+
assert comparator(itertools.count(0.5, 0.1), itertools.count(0.5, 0.1))
434+
435+
# Not equal: different start
436+
assert not comparator(itertools.count(0), itertools.count(1))
437+
assert not comparator(itertools.count(5), itertools.count(10))
438+
439+
# Not equal: different step
440+
assert not comparator(itertools.count(0, 1), itertools.count(0, 2))
441+
assert not comparator(itertools.count(0, 1), itertools.count(0, -1))
442+
443+
# Not equal: different type
444+
assert not comparator(itertools.count(0), 0)
445+
assert not comparator(itertools.count(0), [0, 1, 2])
446+
447+
# Equal after partial consumption (both advanced to the same state)
448+
a = itertools.count(0)
449+
b = itertools.count(0)
450+
next(a)
451+
next(b)
452+
assert comparator(a, b)
453+
454+
# Not equal after different consumption
455+
a = itertools.count(0)
456+
b = itertools.count(0)
457+
next(a)
458+
assert not comparator(a, b)
459+
460+
# Works inside containers
461+
assert comparator([itertools.count(0)], [itertools.count(0)])
462+
assert comparator({"key": itertools.count(5, 2)}, {"key": itertools.count(5, 2)})
463+
assert not comparator([itertools.count(0)], [itertools.count(1)])
464+
465+
466+
def test_itertools_repeat() -> None:
467+
import itertools
468+
469+
# Equal: infinite repeat
470+
assert comparator(itertools.repeat(5), itertools.repeat(5))
471+
assert comparator(itertools.repeat("hello"), itertools.repeat("hello"))
472+
473+
# Equal: bounded repeat
474+
assert comparator(itertools.repeat(5, 3), itertools.repeat(5, 3))
475+
assert comparator(itertools.repeat(None, 10), itertools.repeat(None, 10))
476+
477+
# Not equal: different value
478+
assert not comparator(itertools.repeat(5), itertools.repeat(6))
479+
assert not comparator(itertools.repeat(5, 3), itertools.repeat(6, 3))
480+
481+
# Not equal: different count
482+
assert not comparator(itertools.repeat(5, 3), itertools.repeat(5, 4))
483+
484+
# Not equal: bounded vs infinite
485+
assert not comparator(itertools.repeat(5), itertools.repeat(5, 3))
486+
487+
# Not equal: different type
488+
assert not comparator(itertools.repeat(5), 5)
489+
assert not comparator(itertools.repeat(5), [5])
490+
491+
# Equal after partial consumption
492+
a = itertools.repeat(5, 5)
493+
b = itertools.repeat(5, 5)
494+
next(a)
495+
next(b)
496+
assert comparator(a, b)
497+
498+
# Not equal after different consumption
499+
a = itertools.repeat(5, 5)
500+
b = itertools.repeat(5, 5)
501+
next(a)
502+
assert not comparator(a, b)
503+
504+
# Works inside containers
505+
assert comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 3)])
506+
assert not comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 4)])
507+
508+
509+
def test_itertools_cycle() -> None:
510+
import itertools
511+
512+
# Equal: same sequence
513+
assert comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 3]))
514+
assert comparator(itertools.cycle("abc"), itertools.cycle("abc"))
515+
516+
# Not equal: different sequence
517+
assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 4]))
518+
assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2]))
519+
520+
# Not equal: different type
521+
assert not comparator(itertools.cycle([1, 2, 3]), [1, 2, 3])
522+
523+
# Equal after same partial consumption
524+
a = itertools.cycle([1, 2, 3])
525+
b = itertools.cycle([1, 2, 3])
526+
next(a)
527+
next(b)
528+
assert comparator(a, b)
529+
530+
# Not equal after different consumption
531+
a = itertools.cycle([1, 2, 3])
532+
b = itertools.cycle([1, 2, 3])
533+
next(a)
534+
assert not comparator(a, b)
535+
536+
# Equal after consuming a full cycle
537+
a = itertools.cycle([1, 2, 3])
538+
b = itertools.cycle([1, 2, 3])
539+
for _ in range(3):
540+
next(a)
541+
next(b)
542+
assert comparator(a, b)
543+
544+
# Equal at same position across different full-cycle counts
545+
a = itertools.cycle([1, 2, 3])
546+
b = itertools.cycle([1, 2, 3])
547+
for _ in range(4):
548+
next(a)
549+
for _ in range(7):
550+
next(b)
551+
# Both at position 1 within the cycle (4%3 == 7%3 == 1)
552+
assert comparator(a, b)
553+
554+
# Works inside containers
555+
assert comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 2])])
556+
assert not comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 3])])
557+
558+
559+
def test_itertools_chain() -> None:
560+
import itertools
561+
562+
assert comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 4]))
563+
assert not comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 5]))
564+
assert comparator(itertools.chain.from_iterable([[1, 2], [3]]), itertools.chain.from_iterable([[1, 2], [3]]))
565+
assert comparator(itertools.chain(), itertools.chain())
566+
assert not comparator(itertools.chain([1]), itertools.chain([1, 2]))
567+
568+
569+
def test_itertools_islice() -> None:
570+
import itertools
571+
572+
assert comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 5))
573+
assert not comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 6))
574+
assert comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 5))
575+
assert not comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 6))
576+
577+
578+
def test_itertools_product() -> None:
579+
import itertools
580+
581+
assert comparator(itertools.product("AB", repeat=2), itertools.product("AB", repeat=2))
582+
assert not comparator(itertools.product("AB", repeat=2), itertools.product("AC", repeat=2))
583+
assert comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 4]))
584+
assert not comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 5]))
585+
586+
587+
def test_itertools_permutations_combinations() -> None:
588+
import itertools
589+
590+
assert comparator(itertools.permutations("ABC", 2), itertools.permutations("ABC", 2))
591+
assert not comparator(itertools.permutations("ABC", 2), itertools.permutations("ABD", 2))
592+
assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2))
593+
assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3))
594+
assert comparator(
595+
itertools.combinations_with_replacement("ABC", 2),
596+
itertools.combinations_with_replacement("ABC", 2),
597+
)
598+
assert not comparator(
599+
itertools.combinations_with_replacement("ABC", 2),
600+
itertools.combinations_with_replacement("ABD", 2),
601+
)
602+
603+
604+
def test_itertools_accumulate() -> None:
605+
import itertools
606+
607+
assert comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 4]))
608+
assert not comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 5]))
609+
assert comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=10))
610+
assert not comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=0))
611+
612+
613+
def test_itertools_filtering() -> None:
614+
import itertools
615+
616+
# compress
617+
assert comparator(
618+
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
619+
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
620+
)
621+
assert not comparator(
622+
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
623+
itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]),
624+
)
625+
626+
# dropwhile
627+
assert comparator(
628+
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
629+
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
630+
)
631+
assert not comparator(
632+
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
633+
itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]),
634+
)
635+
636+
# takewhile
637+
assert comparator(
638+
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
639+
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
640+
)
641+
assert not comparator(
642+
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
643+
itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]),
644+
)
645+
646+
# filterfalse
647+
assert comparator(
648+
itertools.filterfalse(lambda x: x % 2, range(10)),
649+
itertools.filterfalse(lambda x: x % 2, range(10)),
650+
)
651+
652+
653+
def test_itertools_starmap() -> None:
654+
import itertools
655+
656+
assert comparator(
657+
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
658+
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
659+
)
660+
assert not comparator(
661+
itertools.starmap(pow, [(2, 3), (3, 2)]),
662+
itertools.starmap(pow, [(2, 3), (3, 3)]),
663+
)
664+
665+
666+
def test_itertools_zip_longest() -> None:
667+
import itertools
668+
669+
assert comparator(
670+
itertools.zip_longest("AB", "xyz", fillvalue="-"),
671+
itertools.zip_longest("AB", "xyz", fillvalue="-"),
672+
)
673+
assert not comparator(
674+
itertools.zip_longest("AB", "xyz", fillvalue="-"),
675+
itertools.zip_longest("AB", "xyz", fillvalue="*"),
676+
)
677+
678+
679+
def test_itertools_groupby() -> None:
680+
import itertools
681+
682+
assert comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBBCC"))
683+
assert not comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBCC"))
684+
assert comparator(itertools.groupby([]), itertools.groupby([]))
685+
686+
# With key function
687+
assert comparator(
688+
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
689+
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
690+
)
691+
692+
693+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="itertools.pairwise requires Python 3.10+")
694+
def test_itertools_pairwise() -> None:
695+
import itertools
696+
697+
assert comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 4]))
698+
assert not comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 5]))
699+
700+
701+
@pytest.mark.skipif(sys.version_info < (3, 12), reason="itertools.batched requires Python 3.12+")
702+
def test_itertools_batched() -> None:
703+
import itertools
704+
705+
assert comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 3))
706+
assert not comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 2))
707+
708+
709+
def test_itertools_in_containers() -> None:
710+
import itertools
711+
712+
# Itertools objects nested in dicts/lists
713+
assert comparator(
714+
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
715+
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
716+
)
717+
assert not comparator(
718+
[itertools.product("AB", repeat=2)],
719+
[itertools.product("AC", repeat=2)],
720+
)
721+
722+
# Different itertools types should not match
723+
assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2))
724+
725+
420726
def test_numpy():
421727
try:
422728
import numpy as np

0 commit comments

Comments
 (0)