Skip to content

Commit 8ce3987

Browse files
committed
Update atomdict
1 parent b2b0595 commit 8ce3987

3 files changed

Lines changed: 60 additions & 29 deletions

File tree

atom/src/atomdict.cpp

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,7 @@ int AtomDict_traverse( AtomDict* self, visitproc visit, void* arg )
103103
{
104104
Py_VISIT( self->m_key_validator );
105105
Py_VISIT( self->m_value_validator );
106-
#if PY_VERSION_HEX >= 0x03090000
107-
// This was not needed before Python 3.9 (Python issue 35810 and 40217)
108-
Py_VISIT(Py_TYPE(self));
109-
#endif
110-
// PyDict_type is not heap allocated so it does visit the type
106+
Py_VISIT(Py_TYPE(self));
111107
return PyDict_Type.tp_traverse( pyobject_cast( self ), visit, arg );
112108
}
113109

@@ -153,17 +149,21 @@ PyObject* AtomDict_setdefault( AtomDict* self, PyObject* args )
153149
{
154150
return 0;
155151
}
156-
PyObject* value = PyDict_GetItem( pyobject_cast( self ), key );
152+
// Key must be validated before use due to possible coercion in AtomDict_ass_subscript
153+
cppy::ptr key_ptr( validate_key( self, key ) );
154+
if ( !key_ptr )
155+
return 0;
156+
PyObject* value = PyDict_GetItem( pyobject_cast( self ), key_ptr.get() );
157157
if( value )
158158
{
159159
return cppy::incref( value );
160160
}
161-
if( AtomDict_ass_subscript( self, key, dfv ) < 0 )
161+
if( AtomDict_ass_subscript( self, key_ptr.get(), dfv ) < 0 )
162162
{
163163
return 0;
164164
}
165165
// Get the dictionary from the dict itself in case it was coerced.
166-
return cppy::incref( PyDict_GetItem( pyobject_cast( self ), key ) );
166+
return cppy::incref( PyDict_GetItem( pyobject_cast( self ), key_ptr.get() ) );
167167
}
168168

169169

@@ -262,25 +262,26 @@ static PyObject* DefaultAtomDict_repr( DefaultAtomDict* self )
262262
{
263263
return 0;
264264
}
265-
ostr << PyUnicode_AsUTF8( repr.get() );
265+
const char* factory_repr = PyUnicode_AsUTF8( repr.get() );
266+
if ( !factory_repr )
267+
return 0;
268+
ostr << factory_repr;
266269
ostr << ", ";
267270
repr = PyDict_Type.tp_repr( pyobject_cast( self ) );
268271
if( !repr )
269272
{
270273
return 0;
271274
}
272-
ostr << PyUnicode_AsUTF8( repr.get() );
275+
const char* self_repr = PyUnicode_AsUTF8( repr.get() );
276+
if ( !self_repr )
277+
return 0;
278+
ostr << self_repr;
273279
ostr << ")";
274280
return PyUnicode_FromString( ostr.str().c_str() );
275281
}
276282

277-
static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* args )
283+
static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* key )
278284
{
279-
PyObject* key;
280-
if( !PyArg_UnpackTuple( args, "__missing__", 1, 1, &key ) )
281-
{
282-
return 0;
283-
}
284285
CAtom* atom = self->dict.pointer->data();
285286
if( !atom )
286287
{
@@ -289,34 +290,33 @@ static PyObject* DefaultAtomDict_missing( DefaultAtomDict* self, PyObject* args
289290
"so missing value cannot be built."
290291
);
291292
}
292-
#if PY_VERSION_HEX >= 0x03090000
293293
cppy::ptr value_ptr( PyObject_CallOneArg( self->factory, pyobject_cast( atom ) ) );
294-
#else
295-
cppy::ptr temp( PyTuple_Pack(1, pyobject_cast( atom ) ) );
296-
cppy::ptr value_ptr( PyObject_Call( self->factory, temp.get(), 0 ) );
297-
#endif
298294
if( !value_ptr )
299295
{
300296
return 0;
301297
}
302298
if( should_validate( atomdict_cast( self ) ) )
303299
{
300+
// Key must be validated before use due to possible coercion in AtomDict_ass_subscript
301+
cppy::ptr key_ptr( validate_key( atomdict_cast( self ), key ) );
302+
if ( !key_ptr )
303+
return 0;
304304
// We cannot simply validate the value as it will be re-validated when
305305
// it is set which leads to creating a different object.
306-
if( AtomDict_ass_subscript( atomdict_cast( self ), key, value_ptr.get() ) < 0 )
306+
if( AtomDict_ass_subscript( atomdict_cast( self ), key_ptr.get(), value_ptr.get() ) < 0 )
307307
{
308308
return 0;
309309
}
310310
// Get the dictionary from the dict itself in case it was coerced.
311-
value_ptr = cppy::incref( PyDict_GetItem( pyobject_cast( self ), key ) );
311+
value_ptr = cppy::incref( PyDict_GetItem( pyobject_cast( self ), key_ptr.get() ) );
312312
}
313313
return value_ptr.release();
314314
}
315315

316316
static PyMethodDef DefaultAtomDict_methods[] = {
317317
{ "__missing__",
318318
( PyCFunction )DefaultAtomDict_missing,
319-
METH_VARARGS,
319+
METH_O,
320320
"Called when a key is absent from the dictionary" },
321321
{ 0 } // sentinel
322322
};
@@ -370,6 +370,8 @@ PyObject* AtomDict::New( CAtom* atom, Member* key_validator, Member* value_valid
370370
int AtomDict::Update( AtomDict* dict, PyObject* value )
371371
{
372372
cppy::ptr validated_dict( PyDict_New() );
373+
if ( !validated_dict )
374+
return -1; // LCOV_EXCL_LINE (failed dict creation)
373375
PyObject* key;
374376
PyObject* val;
375377
Py_ssize_t index = 0;
@@ -455,10 +457,12 @@ bool DefaultAtomDict::Ready()
455457
{
456458
// This will work only if we create this type after the standard AtomDict
457459
// The reference will be handled by the module to which we will add the type
458-
PyObject* bases = PyTuple_New( 1 );
459-
PyTuple_SET_ITEM( bases, 0, pyobject_cast( AtomDict::TypeObject ) );
460+
cppy::ptr bases( PyTuple_New( 1 ) );
461+
if ( !bases )
462+
return false; // LCOV_EXCL_LINE (failed tuple creation)
463+
PyTuple_SET_ITEM( bases.get(), 0, cppy::incref( pyobject_cast( AtomDict::TypeObject ) ) );
460464
TypeObject = pytype_cast(
461-
PyType_FromSpecWithBases( &TypeObject_Spec, bases )
465+
PyType_FromSpecWithBases( &TypeObject_Spec, bases.get() )
462466
);
463467
if( !TypeObject )
464468
{

tests/test_atomdefaultdict.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,16 @@
1111

1212
import pytest
1313

14-
from atom.api import Atom, DefaultDict, Instance, Int, List, atomlist, defaultatomdict
14+
from atom.api import (
15+
Atom,
16+
Coerced,
17+
DefaultDict,
18+
Instance,
19+
Int,
20+
List,
21+
atomlist,
22+
defaultatomdict,
23+
)
1524

1625

1726
@pytest.fixture
@@ -237,3 +246,11 @@ class A(Atom):
237246
content = a.d[1]
238247
assert isinstance(content, atomlist)
239248
assert content is a.d[1]
249+
250+
251+
def test_coerced_key_missing():
252+
class Obj(Atom):
253+
items = DefaultDict(key=Coerced(str), missing=lambda: "missing")
254+
255+
o = Obj()
256+
o.items[1] # key 1 gets coerced to '1'

tests/test_atomdict.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import pytest
1111

12-
from atom.api import Atom, Dict, Int, List, atomdict, atomlist
12+
from atom.api import Atom, Coerced, Dict, Int, List, atomdict, atomlist
1313

1414

1515
@pytest.fixture
@@ -186,3 +186,13 @@ def test_update(atom_dict):
186186
atom_dict.fullytyped.update({"": 1})
187187
with pytest.raises(TypeError):
188188
atom_dict.fullytyped.update({"": ""})
189+
190+
191+
def test_coerced_setdefault():
192+
class Obj(Atom):
193+
items = Dict(key=Coerced(str))
194+
195+
o = Obj()
196+
o.items["1"] = "a"
197+
o.items.setdefault(1, "b") # key 1 gets coerced to '1'
198+
assert o.items["1"] == "a"

0 commit comments

Comments
 (0)