Skip to content

Commit 57de24d

Browse files
committed
Update sortedmap
1 parent d9454af commit 57de24d

2 files changed

Lines changed: 110 additions & 45 deletions

File tree

atom/src/sortedmap.cpp

Lines changed: 70 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,48 @@ struct SortedMap
131131

132132
static bool Ready();
133133

134+
static PyObject* New( PyTypeObject* type, PyObject* map )
135+
{
136+
cppy::ptr selfptr( PyType_GenericNew( type, 0, 0 ) );
137+
if( !selfptr ) {
138+
return 0; // LCOV_EXCL_LINE (allocation failed, very unlikely)
139+
}
140+
SortedMap* self = reinterpret_cast<SortedMap*>( selfptr.get() );
141+
self->m_items = new SortedMap::Items();
142+
143+
if( map )
144+
{
145+
if( PyDict_Check(map) )
146+
{
147+
PyObject* key;
148+
PyObject* val;
149+
Py_ssize_t index = 0;
150+
while( PyDict_Next( map, &index, &key, &val ) )
151+
self->setitem( key, val );
152+
}
153+
else {
154+
cppy::ptr iter( PyObject_GetIter( map ) );
155+
if( !iter )
156+
return 0;
157+
cppy::ptr item;
158+
while( (item = PyIter_Next( iter.get() )) )
159+
{
160+
cppy::ptr pair( PySequence_Fast( item.get(), "map must be a sequence of key, value pairs") );
161+
if ( !pair )
162+
return 0;
163+
if( PySequence_Fast_GET_SIZE( pair.get() ) != 2 )
164+
return cppy::type_error( pair.get(), "pairs of objects" );
165+
self->setitem( PySequence_Fast_GET_ITEM( pair.get(), 0 ),
166+
PySequence_Fast_GET_ITEM( pair.get(), 1 ) );
167+
}
168+
if ( PyErr_Occurred() )
169+
return 0; // error during iteration
170+
}
171+
}
172+
173+
return selfptr.release();
174+
}
175+
134176
PyObject* getitem( PyObject* key, PyObject* default_value = 0 )
135177
{
136178
Items::iterator it = std::lower_bound(
@@ -288,46 +330,22 @@ SortedMap_new( PyTypeObject* type, PyObject* args, PyObject* kwargs )
288330
static char* kwlist[] = { "map", 0 };
289331
if( !PyArg_ParseTupleAndKeywords( args, kwargs, "|O:__new__", kwlist, &map ) )
290332
return 0;
333+
return SortedMap::New( type, map );
334+
}
291335

292-
PyObject* self = PyType_GenericNew( type, 0, 0 );
293-
if( !self ) {
294-
return 0; // LCOV_EXCL_LINE (allocation failed, very unlikely)
295-
}
296-
SortedMap* cself = reinterpret_cast<SortedMap*>( self );
297-
cself->m_items = new SortedMap::Items();
298-
299-
cppy::ptr seq;
300-
if( map )
301-
{
302-
if( PyDict_Check( map ) )
303-
{
304-
seq = PyObject_GetIter( PyDict_Items( map ) );
305-
if( !seq ) {
306-
return 0; // LCOV_EXCL_LINE (dict items failed, very unlikely)
307-
}
308-
}
309-
else
310-
{
311-
seq = PyObject_GetIter( map );
312-
if( !seq )
313-
return 0;
314-
}
315-
}
316-
317-
if( seq )
318-
{
319-
cppy::ptr item;
320-
while( (item = PyIter_Next( seq.get() )) )
321-
{
322-
if( PySequence_Length( item.get() ) != 2)
323-
return cppy::type_error( item.get(), "pairs of objects" );
324-
325-
cself->setitem( PySequence_GetItem( item.get(), 0 ),
326-
PySequence_GetItem( item.get(), 1 ) );
327-
}
336+
PyObject*
337+
SortedMap_vectorcall( PyObject* type, PyObject*const *args, size_t nargsf, PyObject* kwnames )
338+
{
339+
if ( kwnames )
340+
return cppy::type_error("sortedmap takes no kwargs");
341+
switch (PyVectorcall_NARGS(nargsf)) {
342+
case 0:
343+
return SortedMap::New( reinterpret_cast<PyTypeObject*>(type), 0 );
344+
case 1:
345+
return SortedMap::New( reinterpret_cast<PyTypeObject*>(type), args[0] );
346+
default:
347+
return cppy::type_error("sortedmap takes at most one argument");
328348
}
329-
330-
return self;
331349
}
332350

333351
// Clearing the vector may cause arbitrary side effects on item
@@ -353,22 +371,21 @@ SortedMap_traverse( SortedMap* self, visitproc visit, void* arg )
353371
Py_VISIT( it->key() );
354372
Py_VISIT( it->value() );
355373
}
356-
#if PY_VERSION_HEX >= 0x03090000
357-
// This was not needed before Python 3.9 (Python issue 35810 and 40217)
358374
Py_VISIT(Py_TYPE(self));
359-
#endif
360375
return 0;
361376
}
362377

363378

364379
void
365380
SortedMap_dealloc( SortedMap* self )
366381
{
382+
PyTypeObject *tp = Py_TYPE(self);
367383
PyObject_GC_UnTrack( self );
368384
SortedMap_clear( self );
369385
delete self->m_items;
370386
self->m_items = 0;
371-
Py_TYPE(self)->tp_free( reinterpret_cast<PyObject*>( self ) );
387+
tp->tp_free( pyobject_cast( self ) );
388+
Py_DECREF(tp);
372389
}
373390

374391

@@ -508,8 +525,14 @@ SortedMap_repr( SortedMap* self )
508525
cppy::ptr valstr( PyObject_Repr( it->value() ) );
509526
if( !valstr )
510527
return 0;
511-
ostr << "(" << PyUnicode_AsUTF8( keystr.get() ) << ", ";
512-
ostr << PyUnicode_AsUTF8( valstr.get() ) << "), ";
528+
const char* k = PyUnicode_AsUTF8( keystr.get() );
529+
if ( !k )
530+
return 0;
531+
const char* v = PyUnicode_AsUTF8( valstr.get() );
532+
if ( !v )
533+
return 0;
534+
ostr << "(" << k << ", ";
535+
ostr << v << "), ";
513536
}
514537
if( self->m_items->size() > 0 )
515538
ostr.seekp( -2, std::ios_base::cur );
@@ -574,6 +597,9 @@ static PyType_Slot SortedMap_Type_slots[] = {
574597
{ Py_tp_new, void_cast( SortedMap_new ) }, /* tp_new */
575598
{ Py_tp_iter, void_cast( SortedMap_iter ) }, /* tp_iter */
576599
{ Py_tp_alloc, void_cast( PyType_GenericAlloc ) }, /* tp_alloc */
600+
#if defined(Py_tp_vectorcall)
601+
{ Py_tp_vectorcall, void_cast( SortedMap_vectorcall ) }, /* tp_vectorcall */
602+
#endif
577603
{ Py_mp_length, void_cast( SortedMap_length ) }, /* mp_length */
578604
{ Py_mp_subscript, void_cast( SortedMap_subscript ) }, /* mp_subscript */
579605
{ Py_mp_ass_subscript, void_cast( SortedMap_ass_subscript ) }, /* mp_ass_subscript */

tests/datastructure/test_sortedmap.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""Test the sortedmap that acts like an ordered dictionary."""
99

1010
import gc
11+
import sys
1112

1213
import pytest
1314

@@ -36,9 +37,47 @@ def test_sortedmap_init():
3637

3738
with pytest.raises(TypeError):
3839
sortedmap(1)
40+
with pytest.raises(TypeError):
41+
sortedmap(a=1)
42+
with pytest.raises(TypeError):
43+
sortedmap(1, 2)
3944
with pytest.raises(TypeError) as excinfo:
4045
sortedmap([1])
41-
assert "pairs" in excinfo.exconly()
46+
assert "sequence of key, value pairs" in excinfo.exconly()
47+
with pytest.raises(TypeError) as excinfo:
48+
sortedmap([[1]])
49+
assert "pairs of objects" in excinfo.exconly()
50+
51+
52+
def test_sortedmap_gen_err():
53+
"""Test that iterator error is raised"""
54+
55+
def generator(throw):
56+
yield ("a", 1)
57+
if throw:
58+
raise ValueError()
59+
60+
smap = sortedmap(generator(throw=False))
61+
assert smap["a"] == 1
62+
with pytest.raises(ValueError):
63+
smap = sortedmap(generator(throw=True))
64+
65+
66+
def test_sortedmap_refcnt():
67+
"""Test that constructor does not leak references"""
68+
k = object()
69+
v = object()
70+
rck = sys.getrefcount(k)
71+
rcv = sys.getrefcount(v)
72+
smap = sortedmap([(k, v)])
73+
assert smap[k] == v
74+
del smap
75+
smap = sortedmap({k: v})
76+
assert smap[k] == v
77+
del smap
78+
gc.collect()
79+
assert sys.getrefcount(k) == rck
80+
assert sys.getrefcount(v) == rcv
4281

4382

4483
def test_traverse():

0 commit comments

Comments
 (0)