@@ -78,17 +78,31 @@ def __init__(self, field, args, var, convert=True):
7878
7979
8080class VectorFieldNode (IntrinsicNode ):
81+ def __getattr__ (self , attr ):
82+ if attr == "eval" :
83+ return VectorFieldEvalCallNode (self )
84+ else :
85+ raise NotImplementedError ('Access to VectorField attributes are not (yet) implemented in JIT mode' )
86+
8187 def __getitem__ (self , attr ):
8288 return VectorFieldEvalNode (self .obj , attr )
8389
8490
91+ class VectorFieldEvalCallNode (IntrinsicNode ):
92+ def __init__ (self , field ):
93+ self .field = field
94+ self .obj = field .obj
95+ self .ccode = ""
96+
97+
8598class VectorFieldEvalNode (IntrinsicNode ):
86- def __init__ (self , field , args , var , var2 , var3 ):
99+ def __init__ (self , field , args , var , var2 , var3 , convert = True ):
87100 self .field = field
88101 self .args = args
89102 self .var = var # the variable in which the interpolated field is written
90103 self .var2 = var2 # second variable for UV interpolation
91104 self .var3 = var3 # third variable for UVW interpolation
105+ self .convert = convert # whether to convert the result (like field.applyConversion)
92106
93107
94108class SummedFieldNode (IntrinsicNode ):
@@ -412,6 +426,26 @@ def visit_Call(self, node):
412426 self .stmt_stack += [FieldEvalNode (node .func .field , args , tmp , convert )]
413427 return ast .Name (id = tmp )
414428
429+ elif isinstance (node .func , VectorFieldEvalCallNode ):
430+ # get a temporary value to assign result to
431+ tmp1 = self .get_tmp ()
432+ tmp2 = self .get_tmp ()
433+ tmp3 = self .get_tmp () if node .func .field .obj .vector_type == '3D' else None
434+ # whether to convert
435+ convert = True
436+ if "applyConversion" in node .keywords :
437+ k = node .keywords ["applyConversion" ]
438+ if isinstance (k , ast .NameConstant ):
439+ convert = k .value
440+
441+ # convert args to Index(Tuple(*args))
442+ args = ast .Index (value = ast .Tuple (node .args , ast .Load ()))
443+
444+ self .stmt_stack += [VectorFieldEvalNode (node .func .field , args , tmp1 , tmp2 , tmp3 , convert )]
445+ if tmp3 :
446+ return ast .Tuple ([ast .Name (id = tmp1 ), ast .Name (id = tmp2 ), ast .Name (id = tmp3 )], ast .Load ())
447+ else :
448+ return ast .Tuple ([ast .Name (id = tmp1 ), ast .Name (id = tmp2 )], ast .Load ())
415449 return node
416450
417451
@@ -907,14 +941,14 @@ def visit_VectorFieldEvalNode(self, node):
907941 args = self ._check_FieldSamplingArguments (node .args .ccode )
908942 ccode_eval = node .field .obj .ccode_eval_array (node .var , node .var2 , node .var3 ,
909943 node .field .obj .U , node .field .obj .V , node .field .obj .W , * args )
910- if node .field .obj .U .interp_method != 'cgrid_velocity' :
944+ if node .convert and node . field .obj .U .interp_method != 'cgrid_velocity' :
911945 ccode_conv1 = node .field .obj .U .ccode_convert (* args )
912946 ccode_conv2 = node .field .obj .V .ccode_convert (* args )
913947 statements = [c .Statement ("%s *= %s" % (node .var , ccode_conv1 )),
914948 c .Statement ("%s *= %s" % (node .var2 , ccode_conv2 ))]
915949 else :
916950 statements = []
917- if node .field .obj .vector_type == '3D' :
951+ if node .convert and node . field .obj .vector_type == '3D' :
918952 ccode_conv3 = node .field .obj .W .ccode_convert (* args )
919953 statements .append (c .Statement ("%s *= %s" % (node .var3 , ccode_conv3 )))
920954 conv_stat = c .Block (statements )
@@ -1058,14 +1092,14 @@ def visit_VectorFieldEvalNode(self, node):
10581092 self .visit (node .args )
10591093 args = self ._check_FieldSamplingArguments (node .args .ccode )
10601094 ccode_eval = node .field .obj .ccode_eval_object (node .var , node .var2 , node .var3 , node .field .obj .U , node .field .obj .V , node .field .obj .W , * args )
1061- if node .field .obj .U .interp_method != 'cgrid_velocity' :
1095+ if node .convert and node . field .obj .U .interp_method != 'cgrid_velocity' :
10621096 ccode_conv1 = node .field .obj .U .ccode_convert (* args )
10631097 ccode_conv2 = node .field .obj .V .ccode_convert (* args )
10641098 statements = [c .Statement ("%s *= %s" % (node .var , ccode_conv1 )),
10651099 c .Statement ("%s *= %s" % (node .var2 , ccode_conv2 ))]
10661100 else :
10671101 statements = []
1068- if node .field .obj .vector_type == '3D' :
1102+ if node .convert and node . field .obj .vector_type == '3D' :
10691103 ccode_conv3 = node .field .obj .W .ccode_convert (* args )
10701104 statements .append (c .Statement ("%s *= %s" % (node .var3 , ccode_conv3 )))
10711105 conv_stat = c .Block (statements )
0 commit comments