@@ -101,7 +101,7 @@ def for_header_to_tuple(target, target_type, iter_) -> t.Tuple[
101101 init = typed_ast3 .Assign (targets = [target ], value = begin , type_comment = None )
102102 else :
103103 init = typed_ast3 .AnnAssign (target = target , annotation = target_type , value = begin , simple = True )
104- condition = typed_ast3 .Compare (left = target , ops = [typed_ast3 .Lt ()], comparators = [end ])
104+ condition = typed_ast3 .Expr ( typed_ast3 . Compare (left = target , ops = [typed_ast3 .Lt ()], comparators = [end ]) )
105105 increment = typed_ast3 .AugAssign (target = target , op = typed_ast3 .Add (), value = step )
106106 return init , condition , increment
107107
@@ -133,6 +133,7 @@ def _unsupported_syntax(self, syntax, comment: str = None):
133133 self .fill ('! TODO: unsupported syntax' )
134134 unparsing_unsupported ('C++' , syntax , comment )
135135
136+ # Misusing ast.Index to signify reference types. Should probably fork typed-ast or something. The ast generalizer does something string based.
136137 def dispatch_type (self , type_hint ):
137138 _LOG .debug ('dispatching type hint %s' , type_hint )
138139 if is_generic_type (type_hint ):
@@ -154,8 +155,32 @@ def dispatch_type(self, type_hint):
154155 self .write ('*' )
155156 return
156157 if isinstance (type_hint , typed_ast3 .Subscript ):
158+ < << << << HEAD
157159 _LOG .error ('encountered unsupported subscript form: %s' ,
158160 horast .unparse (type_hint ).strip ())
161+ == == == =
162+ if isinstance (type_hint .value , typed_ast3 .Attribute ) \
163+ and isinstance (type_hint .value .value , typed_ast3 .Name ):
164+ unparsed = horast .unparse (type_hint .value ).strip ()
165+ self .write (PY_TO_CPP_TYPES [unparsed ])
166+ if unparsed == 'st.ndarray' :
167+ self .write ('<' )
168+ sli = type_hint .slice
169+ self .write ('>' )
170+ return
171+ elif isinstance (type_hint .value , typed_ast3 .Name ):
172+ unparsed = horast .unparse (type_hint .value ).strip ()
173+ self .write (unparsed )
174+ self .write ('<' )
175+ if isinstance (type_hint .slice , typed_ast3 .Subscript ):
176+ self .dispatch_type (type_hint .slice )
177+ else :
178+ self .write (horast .unparse (type_hint .slice ).strip ())
179+ self .write (' >' )
180+ if isinstance (type_hint .slice , typed_ast3 .Index ):
181+ self .write ('&' )
182+ return
183+ > >> >> >> sql - to - cpp
159184 self ._unsupported_syntax (type_hint )
160185 if isinstance (type_hint , typed_ast3 .Attribute ):
161186 if isinstance (type_hint .value , typed_ast3 .Name ):
@@ -169,12 +194,19 @@ def dispatch_type(self, type_hint):
169194 assert type_hint .value is None
170195 self .write ('void' )
171196 return
197+
172198 self .dispatch (type_hint )
199+ if isinstance (type_hint , typed_ast3 .Index ):
200+ if isinstance (type_hint .value , typed_ast3 .Name ):
201+ self .write ('&' )
173202
174203 def _Expr (self , tree ):
175204 super ()._Expr (tree )
176205 self .write (';' )
177206
207+ def _Pass (self , tree ):
208+ self .fill ('/* pass */' )
209+
178210 def _Import (self , t ):
179211 self .fill ('/* Python import' )
180212 # raise NotImplementedError('not supported yet')
@@ -183,7 +215,11 @@ def _Import(self, t):
183215 # #include "boost/multi_array.hpp"
184216
185217 def _ImportFrom (self , t ):
186- raise NotImplementedError ('not supported yet' )
218+ self .fill ('/* Python import' )
219+ # raise NotImplementedError('not supported yet')
220+ super ()._ImportFrom (t )
221+ self .fill ('*/' )
222+ # #include "boost/multi_array.hpp"
187223
188224 def _Assign (self , t ):
189225 if self ._context != 'for header' :
@@ -213,12 +249,19 @@ def _AnnAssign(self, t):
213249 self ._unsupported_syntax (t , 'which is not simple' )
214250 self .dispatch_type (t .annotation )
215251 self .write (' ' )
216- self .dispatch (t .target )
252+ try :
253+ self .dispatch (t .target )
254+ except AttributeError as e :
255+ print (e )
217256 if t .value :
218257 self .write (' = ' )
219258 self .dispatch (t .value )
259+ < << << << HEAD
220260 if self ._context != 'for header' :
221261 self .write (';' )
262+ == == == =
263+ self .write (';' )
264+ > >> >> >> sql - to - cpp
222265
223266 def _Return (self , t ):
224267 super ()._Return (t )
@@ -228,6 +271,7 @@ def _Pass(self, t):
228271 self .fill (';' )
229272
230273 def _ClassDef (self , t ):
274+ < << << << HEAD
231275 self .write ('\n ' )
232276 if t .decorator_list :
233277 self ._unsupported_syntax (t , 'with decorators' )
@@ -248,6 +292,25 @@ def _ClassDef(self, t):
248292 else :
249293 comma = True
250294 self .dispatch (e )
295+ == == == =
296+ if len (t .decorator_list ) > 1 :
297+ self ._unsupported_syntax (t , ' with decorators' )
298+ self .write ('\n ' )
299+ self .fill ()
300+ is_struct = False
301+ if len (t .decorator_list ) == 1 :
302+ is_struct = t .decorator_list [0 ].id == 'struct'
303+ if is_struct :
304+ self .write ('struct' )
305+ else :
306+ self .write ('class' )
307+
308+ self .write (' {}' .format (t .name ))
309+ self .enter ()
310+ self .dispatch (t .body )
311+ self .leave ()
312+
313+ > >> >> >> sql - to - cpp
251314
252315 self .enter ()
253316 for stmt in t .body :
@@ -341,6 +404,7 @@ def _AsyncFunctionDef(self, t):
341404
342405 def _For (self , t ):
343406 self .fill ('for (' )
407+ < << << << HEAD
344408 init , cond , increment = for_header_to_tuple (t .target , t .resolved_type_comment , t .iter )
345409 self ._context = 'for header'
346410 self .dispatch (init )
@@ -349,6 +413,18 @@ def _For(self, t):
349413 self .write ('; ' )
350414 self .dispatch (increment )
351415 self ._context = None
416+ == == == =
417+ init , cond , increment = for_header_to_tuple (t .target , t .type_comment , t .iter )
418+ #self.dispatch(init)
419+ #self.dispatch(cond)
420+ #self.dispatch(increment)
421+ unparser = Cpp14Unparser ()
422+ self .write (unparser .unparse (init ).strip ('\n \r ;)(' ))
423+ self .write ('; ' )
424+ self .write (unparser .unparse (cond ).strip ('\n \r ;)(' ))
425+ self .write ('; ' )
426+ self .write (unparser .unparse (increment ).strip ('\n \r ;)(' ))
427+ >> >> >> > sql - to - cpp
352428 # self.dispatch(t.iter)
353429 self .write (')' )
354430 self .enter ()
@@ -383,6 +459,11 @@ def _If(self, t):
383459 self .dispatch (t .orelse )
384460 self .leave ()
385461
462+ < << << << HEAD
463+ == == == =
464+ #raise NotImplementedError('not supported yet')
465+
466+ >> >> >> > sql - to - cpp
386467 def _While (self , t ):
387468 raise NotImplementedError ('not supported yet' )
388469
@@ -446,6 +527,7 @@ def _Compare(self, compare):
446527
447528 def _Attribute (self , t ):
448529 if isinstance (t .value , typed_ast3 .Name ):
530+ < << << << HEAD
449531 if t .value .id == 'self' :
450532 self .write ('this' )
451533 self .write ('->' )
@@ -456,6 +538,22 @@ def _Attribute(self, t):
456538 self ._includes ['cmath' ] = True
457539 self .write (unparsed )
458540 return
541+ == == == =
542+ try :
543+ unparsed = {
544+ ('a' , 'shape' ): '???' ,
545+ ('b' , 'shape' ): '???' ,
546+ ('c' , 'shape' ): '???' ,
547+ ('np' , 'single' ): 'int32_t' ,
548+ ('np' , 'double' ): 'int64_t' ,
549+ ('np' , 'zeros' ): 'boost::multi_array' ,
550+ ('st' , 'ndarray' ): 'boost::multi_array'
551+ }[t .value .id , t .attr ]
552+ self .write (unparsed )
553+ return
554+ except Exception :
555+ pass
556+ >> >> >> > sql - to - cpp
459557 self .dispatch (t .value )
460558 self .write ('.' )
461559 self .write (t .attr )
@@ -501,18 +599,57 @@ def _Call(self, t):
501599
502600 def _arg (self , t ):
503601 if t .annotation is None :
602+ < << << << HEAD
504603 self ._unsupported_syntax (t , 'without annotation' )
604+ == == == =
605+ self ._unsupported_syntax (t , ' without annotation' )
606+ >> >> >> > sql - to - cpp
505607 self .dispatch_type (t .annotation )
506608 self .write (' ' )
507609 self .write (t .arg )
508610
509611 def _Comment (self , node ):
510- if node .eol :
511- self .write (' //' )
612+ if node .value .s .startswith ("pragma" ):
613+ if node .eol :
614+ self .write ('\n #' )
615+ else :
616+ self .fill ('#' )
512617 else :
513- self .fill ('//' )
618+ if node .eol :
619+ self .write (' //' )
620+ else :
621+ self .fill ('//' )
514622 self .write (node .value .s )
515623
624+ < << << << HEAD
625+ == == == =
626+ def _NameConstant (self , node ):
627+ if node .value == 'False' :
628+ self .write ("false" )
629+ elif node .value == 'True' :
630+ self .write ("true" )
631+ elif node .value == 'None' :
632+ self .write ("null" )
633+ else :
634+ self .write (node .value )
635+
636+ def _Str (self , node ):
637+ self .write ('"' )
638+ self .write (node .s .replace ('"' , '\\ "' ).replace ("\0 " , "\\ 0" ).replace ("\n " , "\\ n" ).replace ("\t " , "\\ t" ).replace ("\r " , "\\ r" ))
639+ self .write ('"' )
640+
641+ def _Bytes (self , node ):
642+ # Char
643+ if len (node .s ) == 1 :
644+ self .write ("'" )
645+ self .write (node .s .replace ("'" , "\\ '" ).replace ("\0 " , "\\ 0" ).replace ("\n " , "\\ n" ).replace ("\t " , "\\ t" ).replace ("\r " , "\\ r" ))
646+ self .write ("'" )
647+ else :
648+ self ._Str (node )
649+
650+ def _unsupported_syntax (self , tree , comment : str = '' ):
651+ raise SyntaxError ('unparsing {}{} to C++ is not supported' .format (type (tree ), comment ))
652+ >> >> >> > sql - to - cpp
516653
517654class Cpp14HeaderUnparserBackend (Cpp14UnparserBackend ):
518655
0 commit comments