Skip to content

Commit 530f5c3

Browse files
committed
[GR-42218] Fix dict and set operation result classes.
PullRequest: graalpython/4607
2 parents ebc2608 + 4856fe7 commit 530f5c3

3 files changed

Lines changed: 88 additions & 3 deletions

File tree

graalpython/com.oracle.graal.python.test/src/tests/test_dict.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
# SOFTWARE.
3939

4040
import unittest, sys
41+
from collections import defaultdict
4142

4243
graalpy_only = unittest.skipUnless(sys.implementation.name == "graalpy", "GraalPy-specific dict storage test")
4344

@@ -577,6 +578,38 @@ def test_copy():
577578
assert set(d1.keys()) == {'a', 'b', 'c'}
578579

579580

581+
def test_defaultdict_operations_subclass_preserve_type():
582+
class DefaultDictSubclass(defaultdict):
583+
pass
584+
585+
d = DefaultDictSubclass(int, a=1)
586+
copied = d.copy()
587+
merged = d | {"b": 2}
588+
rmerged = {"b": 2} | d
589+
590+
assert type(copied) is DefaultDictSubclass
591+
assert copied.default_factory is int
592+
assert dict(copied) == {"a": 1}
593+
assert type(merged) is DefaultDictSubclass
594+
assert merged.default_factory is int
595+
assert dict(merged) == {"a": 1, "b": 2}
596+
assert type(rmerged) is DefaultDictSubclass
597+
assert rmerged.default_factory is int
598+
assert dict(rmerged) == {"b": 2, "a": 1}
599+
600+
601+
def test_dict_operations_return_builtin_dict_for_subclass():
602+
class DictSubclass(dict):
603+
pass
604+
605+
d = DictSubclass(a=1)
606+
other = {"b": 2}
607+
608+
assert type(d.copy()) is dict
609+
assert type(d | other) is dict
610+
assert type(other | d) is dict
611+
612+
580613
def test_keywords():
581614
def modifying(**kwargs):
582615
kwargs["a"] = 10

graalpython/com.oracle.graal.python.test/src/tests/test_set.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,3 +686,40 @@ def test_set_iterator_reduce():
686686
it = s.__iter__()
687687
it.__reduce__()
688688
assert [i for i in it] == [1, 2, 3]
689+
690+
691+
def test_set_operations_return_builtin_set_for_subclass():
692+
class SetSubclass(set):
693+
pass
694+
695+
s = SetSubclass([1, 2])
696+
other = {2, 3}
697+
698+
assert type(s.copy()) is set
699+
assert type(s | other) is set
700+
assert type(other | s) is set
701+
assert type(s & other) is set
702+
assert type(s - other) is set
703+
assert type(s ^ other) is set
704+
assert type(s.union(other)) is set
705+
assert type(s.intersection(other)) is set
706+
assert type(s.difference(other)) is set
707+
assert type(s.symmetric_difference(other)) is set
708+
709+
710+
def test_frozenset_operations_return_builtin_frozenset_for_subclass():
711+
class FrozenSetSubclass(frozenset):
712+
pass
713+
714+
f = FrozenSetSubclass([1, 2])
715+
other = {2, 3}
716+
717+
assert type(f.copy()) is frozenset
718+
assert type(f | other) is frozenset
719+
assert type(f & other) is frozenset
720+
assert type(f - other) is frozenset
721+
assert type(f ^ other) is frozenset
722+
assert type(f.union(other)) is frozenset
723+
assert type(f.intersection(other)) is frozenset
724+
assert type(f.difference(other)) is frozenset
725+
assert type(f.symmetric_difference(other)) is frozenset

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DefaultDictBuiltins.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2021, 2026, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -78,6 +78,7 @@
7878
import com.oracle.graal.python.nodes.object.GetClassNode;
7979
import com.oracle.graal.python.runtime.object.PFactory;
8080
import com.oracle.graal.python.util.PythonUtils;
81+
import com.oracle.truffle.api.HostCompilerDirectives.InliningCutoff;
8182
import com.oracle.truffle.api.dsl.Bind;
8283
import com.oracle.truffle.api.dsl.Cached;
8384
import com.oracle.truffle.api.dsl.Fallback;
@@ -149,13 +150,27 @@ static Object reduce(VirtualFrame frame, PDefaultDict self,
149150
@Builtin(name = "copy", minNumOfPositionalArgs = 1)
150151
@GenerateNodeFactory
151152
public abstract static class CopyNode extends PythonUnaryBuiltinNode {
152-
@Specialization
153-
static PDefaultDict copy(@SuppressWarnings("unused") VirtualFrame frame, PDefaultDict self,
153+
@Specialization(guards = "isBuiltinDefaultDict(self)")
154+
static PDefaultDict copyBuiltin(PDefaultDict self,
154155
@Bind Node inliningTarget,
155156
@Cached HashingStorageCopy copyNode,
156157
@Bind PythonLanguage language) {
157158
return PFactory.createDefaultDict(language, self.getDefaultFactory(), copyNode.execute(inliningTarget, self.getDictStorage()));
158159
}
160+
161+
@Fallback
162+
@InliningCutoff
163+
static Object copyGeneric(VirtualFrame frame, Object self,
164+
@Bind Node inliningTarget,
165+
@Cached GetClassNode getClassNode,
166+
@Cached CallNode callNode) {
167+
PDefaultDict defaultDict = (PDefaultDict) self;
168+
return callNode.execute(frame, getClassNode.execute(inliningTarget, defaultDict), defaultDict.getDefaultFactory(), defaultDict);
169+
}
170+
171+
static boolean isBuiltinDefaultDict(PDefaultDict self) {
172+
return self.getPythonClass() == PythonBuiltinClassType.PDefaultDict;
173+
}
159174
}
160175

161176
@Builtin(name = J___MISSING__, minNumOfPositionalArgs = 2)

0 commit comments

Comments
 (0)