Skip to content

Commit eeda6c2

Browse files
aseembits93claude
andcommitted
fix: handle all remaining itertools types in comparator
Add a catch-all handler for itertools iterators (chain, islice, product, permutations, combinations, starmap, accumulate, compress, dropwhile, takewhile, filterfalse, zip_longest, groupby, pairwise, batched, tee). Uses module check (type.__module__ == "itertools") so it automatically covers any itertools type without version-specific enumeration. groupby gets special handling to also materialize its group iterators. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 456a188 commit eeda6c2

2 files changed

Lines changed: 173 additions & 0 deletions

File tree

codeflash/verification/comparator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,18 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
559559
orig_saved, new_saved, superset_obj
560560
)
561561

562+
# Handle remaining itertools types (chain, islice, starmap, product, permutations, etc.)
563+
# by materializing into lists. count/repeat/cycle are already handled above.
564+
# NOTE: materializing is destructive (consumes the iterator) and will hang on infinite input,
565+
# but the three infinite itertools types are already handled above.
566+
if type(orig).__module__ == "itertools":
567+
if isinstance(orig, itertools.groupby):
568+
# groupby yields (key, group_iterator) — materialize groups too
569+
orig_groups = [(k, list(g)) for k, g in orig]
570+
new_groups = [(k, list(g)) for k, g in new]
571+
return comparator(orig_groups, new_groups, superset_obj)
572+
return comparator(list(orig), list(new), superset_obj)
573+
562574
# re.Pattern can be made better by DFA Minimization and then comparing
563575
if isinstance(
564576
orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)

tests/test_comparator.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,167 @@ def test_itertools_cycle() -> None:
556556
assert not comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 3])])
557557

558558

559+
def test_itertools_chain() -> None:
560+
import itertools
561+
562+
assert comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 4]))
563+
assert not comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 5]))
564+
assert comparator(itertools.chain.from_iterable([[1, 2], [3]]), itertools.chain.from_iterable([[1, 2], [3]]))
565+
assert comparator(itertools.chain(), itertools.chain())
566+
assert not comparator(itertools.chain([1]), itertools.chain([1, 2]))
567+
568+
569+
def test_itertools_islice() -> None:
570+
import itertools
571+
572+
assert comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 5))
573+
assert not comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 6))
574+
assert comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 5))
575+
assert not comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 6))
576+
577+
578+
def test_itertools_product() -> None:
579+
import itertools
580+
581+
assert comparator(itertools.product("AB", repeat=2), itertools.product("AB", repeat=2))
582+
assert not comparator(itertools.product("AB", repeat=2), itertools.product("AC", repeat=2))
583+
assert comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 4]))
584+
assert not comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 5]))
585+
586+
587+
def test_itertools_permutations_combinations() -> None:
588+
import itertools
589+
590+
assert comparator(itertools.permutations("ABC", 2), itertools.permutations("ABC", 2))
591+
assert not comparator(itertools.permutations("ABC", 2), itertools.permutations("ABD", 2))
592+
assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2))
593+
assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3))
594+
assert comparator(
595+
itertools.combinations_with_replacement("ABC", 2),
596+
itertools.combinations_with_replacement("ABC", 2),
597+
)
598+
assert not comparator(
599+
itertools.combinations_with_replacement("ABC", 2),
600+
itertools.combinations_with_replacement("ABD", 2),
601+
)
602+
603+
604+
def test_itertools_accumulate() -> None:
605+
import itertools
606+
607+
assert comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 4]))
608+
assert not comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 5]))
609+
assert comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=10))
610+
assert not comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=0))
611+
612+
613+
def test_itertools_filtering() -> None:
614+
import itertools
615+
616+
# compress
617+
assert comparator(
618+
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
619+
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
620+
)
621+
assert not comparator(
622+
itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
623+
itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]),
624+
)
625+
626+
# dropwhile
627+
assert comparator(
628+
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
629+
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
630+
)
631+
assert not comparator(
632+
itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
633+
itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]),
634+
)
635+
636+
# takewhile
637+
assert comparator(
638+
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
639+
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
640+
)
641+
assert not comparator(
642+
itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
643+
itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]),
644+
)
645+
646+
# filterfalse
647+
assert comparator(
648+
itertools.filterfalse(lambda x: x % 2, range(10)),
649+
itertools.filterfalse(lambda x: x % 2, range(10)),
650+
)
651+
652+
653+
def test_itertools_starmap() -> None:
654+
import itertools
655+
656+
assert comparator(
657+
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
658+
itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]),
659+
)
660+
assert not comparator(
661+
itertools.starmap(pow, [(2, 3), (3, 2)]),
662+
itertools.starmap(pow, [(2, 3), (3, 3)]),
663+
)
664+
665+
666+
def test_itertools_zip_longest() -> None:
667+
import itertools
668+
669+
assert comparator(
670+
itertools.zip_longest("AB", "xyz", fillvalue="-"),
671+
itertools.zip_longest("AB", "xyz", fillvalue="-"),
672+
)
673+
assert not comparator(
674+
itertools.zip_longest("AB", "xyz", fillvalue="-"),
675+
itertools.zip_longest("AB", "xyz", fillvalue="*"),
676+
)
677+
678+
679+
def test_itertools_groupby() -> None:
680+
import itertools
681+
682+
assert comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBBCC"))
683+
assert not comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBCC"))
684+
assert comparator(itertools.groupby([]), itertools.groupby([]))
685+
686+
# With key function
687+
assert comparator(
688+
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
689+
itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x),
690+
)
691+
692+
693+
def test_itertools_pairwise_batched() -> None:
694+
import itertools
695+
696+
assert comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 4]))
697+
assert not comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 5]))
698+
699+
assert comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 3))
700+
assert not comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 2))
701+
702+
703+
def test_itertools_in_containers() -> None:
704+
import itertools
705+
706+
# Itertools objects nested in dicts/lists
707+
assert comparator(
708+
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
709+
{"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)},
710+
)
711+
assert not comparator(
712+
[itertools.product("AB", repeat=2)],
713+
[itertools.product("AC", repeat=2)],
714+
)
715+
716+
# Different itertools types should not match
717+
assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2))
718+
719+
559720
def test_numpy():
560721
try:
561722
import numpy as np

0 commit comments

Comments
 (0)