Skip to content

Commit cab6950

Browse files
committed
Improve record type checking
Signed-off-by: Ben Sherman <bentshermann@gmail.com>
1 parent 05dccba commit cab6950

5 files changed

Lines changed: 104 additions & 32 deletions

File tree

src/main/java/nextflow/script/control/DataflowOpResolver.java

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import nextflow.script.types.Channel;
2424
import nextflow.script.types.Record;
2525
import nextflow.script.types.Tuple;
26+
import nextflow.script.types.TypesEx;
2627
import nextflow.script.types.Value;
2728
import org.codehaus.groovy.ast.ClassHelper;
2829
import org.codehaus.groovy.ast.ClassNode;
@@ -46,6 +47,12 @@ class DataflowOpResolver {
4647
private static final ClassNode TUPLE_TYPE = ClassHelper.makeCached(Tuple.class);
4748
private static final ClassNode VALUE_TYPE = ClassHelper.makeCached(Value.class);
4849

50+
private ClassNode receiverType;
51+
52+
public DataflowOpResolver(ClassNode receiverType) {
53+
this.receiverType = receiverType;
54+
}
55+
4956
/**
5057
* Resolve the return type of dataflow operators where applicable,
5158
* such as `combine`, `groupBy`, and `join`.
@@ -96,7 +103,7 @@ private ClassNode applyCombine(ClassNode lhsType, List<Expression> arguments) {
96103
}
97104

98105
private ClassNode applyCombineNamedArgs(ClassNode lhsType, NamedArgumentListExpression nale) {
99-
if( !isRecordType(lhsType) )
106+
if( !TypesEx.isRecordType(lhsType) )
100107
return ClassHelper.dynamicType();
101108
var rhsType = new ClassNode(Record.class);
102109
for( var entry : nale.getMapEntryExpressions() ) {
@@ -108,7 +115,7 @@ private ClassNode applyCombineNamedArgs(ClassNode lhsType, NamedArgumentListExpr
108115
rhsType.addField(fn);
109116
}
110117
var elementType = recordSumType(lhsType, rhsType);
111-
return makeType(CHANNEL_TYPE, elementType);
118+
return makeType(receiverType, elementType);
112119
}
113120

114121
private static ClassNode dataflowValueType(ClassNode type) {
@@ -119,7 +126,7 @@ private static ClassNode dataflowValueType(ClassNode type) {
119126
return type;
120127
}
121128

122-
private boolean combineTupleOrValue(List<ClassNode> componentTypes, ClassNode type) {
129+
private static boolean combineTupleOrValue(List<ClassNode> componentTypes, ClassNode type) {
123130
if( TUPLE_TYPE.equals(type) ) {
124131
var gts = type.getGenericsTypes();
125132
if( gts == null && gts.length == 0 )
@@ -133,6 +140,12 @@ private boolean combineTupleOrValue(List<ClassNode> componentTypes, ClassNode ty
133140
return true;
134141
}
135142

143+
private ClassNode channelTupleType(GenericsType[] gts) {
144+
var tupleType = TUPLE_TYPE.getPlainNodeReference();
145+
tupleType.setGenericsTypes(gts);
146+
return makeType(receiverType, tupleType);
147+
}
148+
136149
/**
137150
* Resolve the result type of a `groupBy` operation.
138151
*
@@ -167,11 +180,11 @@ private ClassNode applyGroupBy(ClassNode lhsType, List<Expression> arguments) {
167180
* @param arguments
168181
*/
169182
private ClassNode applyJoin(ClassNode lhsType, List<Expression> arguments) {
170-
if( !isRecordType(lhsType) )
183+
if( !TypesEx.isRecordType(lhsType) )
171184
return ClassHelper.dynamicType();
172185
var argType = getType(arguments.get(arguments.size() - 1));
173186
var rhsType = dataflowElementType(argType);
174-
if( !isRecordType(rhsType) )
187+
if( !TypesEx.isRecordType(rhsType) )
175188
return ClassHelper.dynamicType();
176189
// TODO: report error if `by` field is not in both records
177190
var elementType = recordSumType(lhsType, rhsType);
@@ -184,10 +197,4 @@ private static ClassNode dataflowElementType(ClassNode type) {
184197
return ClassHelper.dynamicType();
185198
}
186199

187-
private static ClassNode channelTupleType(GenericsType[] gts) {
188-
var tupleType = TUPLE_TYPE.getPlainNodeReference();
189-
tupleType.setGenericsTypes(gts);
190-
return makeType(CHANNEL_TYPE, tupleType);
191-
}
192-
193200
}

src/main/java/nextflow/script/control/TypeCheckingVisitorEx.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ private void checkOperatorCall(MethodCallExpression node) {
536536
return;
537537

538538
var receiverType = getType(node.getObjectExpression());
539-
if( !CHANNEL_TYPE.equals(receiverType) )
539+
if( !CHANNEL_TYPE.equals(receiverType) && !VALUE_TYPE.equals(receiverType) )
540540
return;
541541

542542
var method = (MethodNode) node.getNodeMetaData(ASTNodeMarker.METHOD_TARGET);
@@ -545,7 +545,7 @@ private void checkOperatorCall(MethodCallExpression node) {
545545

546546
var lhsType = elementType(receiverType);
547547
var arguments = asMethodCallArguments(node);
548-
var resultType = new DataflowOpResolver().apply(lhsType, method, arguments);
548+
var resultType = new DataflowOpResolver(receiverType).apply(lhsType, method, arguments);
549549
if( ClassHelper.isDynamicTyped(resultType) )
550550
return;
551551

@@ -970,7 +970,7 @@ private boolean checkRecordSum(BinaryExpression node) {
970970
var rhs = node.getRightExpression();
971971
var lhsType = getType(lhs);
972972
var rhsType = getType(rhs);
973-
if( !(isRecordType(lhsType) && isRecordType(rhsType)) )
973+
if( !TypesEx.isRecordType(lhsType) || !TypesEx.isRecordType(rhsType) )
974974
return false;
975975

976976
var resultType = recordSumType(lhsType, rhsType);
@@ -1179,7 +1179,7 @@ public void visitCastExpression(CastExpression node) {
11791179
private boolean checkRecordCast(ClassNode targetType, ClassNode sourceType, ASTNode node) {
11801180
if( !(targetType.redirect() instanceof RecordNode) )
11811181
return false;
1182-
if( !isRecordType(sourceType) )
1182+
if( !TypesEx.isRecordType(sourceType) )
11831183
return false;
11841184
for( var target : targetType.getFields() ) {
11851185
if( target.getType().getNodeMetaData(ASTNodeMarker.NULLABLE) != null )

src/main/java/nextflow/script/types/TypeCheckingUtils.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -878,16 +878,6 @@ public static ClassNode elementType(ClassNode type) {
878878
return gts[0].getType();
879879
}
880880

881-
/**
882-
* Determine whether a type is a record type, either
883-
* as Record or a user-defined record type.
884-
*
885-
* @param type
886-
*/
887-
public static boolean isRecordType(ClassNode type) {
888-
return RECORD_TYPE.equals(type) || type.redirect() instanceof RecordNode;
889-
}
890-
891881
/**
892882
* Return a record type for the sum of two records.
893883
*

src/main/java/nextflow/script/types/TypesEx.java

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
*/
5252
public class TypesEx {
5353

54+
private static final ClassNode PARAMS_TYPE = ClassHelper.makeCached(ParamsMap.class);
55+
private static final ClassNode RECORD_TYPE = ClassHelper.makeCached(Record.class);
56+
5457
/**
5558
* Determine whether a method has a non-void return type.
5659
*
@@ -75,20 +78,15 @@ public static boolean hasReturnType(MethodNode node) {
7578
public static boolean isAssignableFrom(ClassNode target, ClassNode source, boolean checkGenerics) {
7679
if( ClassHelper.isObjectType(target) || ClassHelper.isDynamicTyped(source) )
7780
return true;
81+
if( isRecordType(target) && (PARAMS_TYPE.equals(source) || isRecordType(source)) )
82+
return isAssignableFromRecord(target, source);
7883
if( target.equals(source) )
7984
return true;
80-
if( RECORD_TYPE.equals(target) && source.redirect() instanceof RecordNode )
81-
return true;
82-
if( target.redirect() instanceof RecordNode && (PARAMS_TYPE.equals(source) || RECORD_TYPE.equals(source)) )
83-
return isAssignableFromRecord(target, source);
8485
return target.isResolved() && source.isResolved()
8586
&& isAssignableFrom(target.getTypeClass(), source.getTypeClass())
8687
&& (!checkGenerics || isAssignableFrom(target.getGenericsTypes(), source.getGenericsTypes()));
8788
}
8889

89-
private static final ClassNode PARAMS_TYPE = ClassHelper.makeCached(ParamsMap.class);
90-
private static final ClassNode RECORD_TYPE = ClassHelper.makeCached(Record.class);
91-
9290
private static boolean isAssignableFromRecord(ClassNode target, ClassNode source) {
9391
for( var targetFn : target.getFields() ) {
9492
if( targetFn.getType().getNodeMetaData(ASTNodeMarker.NULLABLE) != null )
@@ -192,6 +190,16 @@ public static boolean isNamespace(ClassNode cn) {
192190
return cn.implementsInterface(ClassHelper.makeCached(Namespace.class));
193191
}
194192

193+
/**
194+
* Determine whether a type is a record type, either
195+
* as Record or a user-defined record type.
196+
*
197+
* @param type
198+
*/
199+
public static boolean isRecordType(ClassNode type) {
200+
return RECORD_TYPE.equals(type) || type.redirect() instanceof RecordNode;
201+
}
202+
195203
/**
196204
* Get the display name of a type.
197205
*

src/test/groovy/nextflow/script/types/TypeCheckingTest.groovy

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,47 @@ class TypeCheckingTest extends Specification {
822822
)
823823
}
824824

825+
def 'should check a process call with record inputs' () {
826+
expect:
827+
check(
828+
'''\
829+
nextflow.preview.types = true
830+
831+
process hello {
832+
input:
833+
record(id: String, fastq: Path)
834+
835+
exec:
836+
println '...'
837+
}
838+
839+
workflow {
840+
hello( record(id: '1') )
841+
}
842+
''',
843+
'Argument with type Record {\n id: String\n} is not compatible with process input of type Record {\n id: String\n fastq: Path\n}'
844+
)
845+
and:
846+
check(
847+
'''\
848+
nextflow.preview.types = true
849+
850+
process hello {
851+
input:
852+
record(id: String, fastq: Path)
853+
854+
exec:
855+
println '...'
856+
}
857+
858+
workflow {
859+
hello( record(id: '1', fastq: file('1.fastq')) )
860+
}
861+
''',
862+
null
863+
)
864+
}
865+
825866
def 'should recognize process output type' () {
826867
when:
827868
def exp = parseExpression(
@@ -1083,6 +1124,32 @@ class TypeCheckingTest extends Specification {
10831124
TypesEx.getName(type) == 'Channel<Record {\n id: Integer\n fastq: Path\n single_end: Boolean\n index: Path\n}>'
10841125
}
10851126

1127+
def 'should resolve a `combine` operation on a dataflow value' () {
1128+
when:
1129+
def exp = parseExpression(
1130+
'''\
1131+
left = channel.value( 42 )
1132+
right = channel.value( 'hello' )
1133+
left.combine(right)
1134+
'''
1135+
)
1136+
def type = getType(exp)
1137+
then:
1138+
TypesEx.getName(type) == 'Value<Tuple<Integer, String>>'
1139+
1140+
when:
1141+
exp = parseExpression(
1142+
'''\
1143+
sample = channel.value( record(id: 1, fastq: file('1.fq')) )
1144+
index = channel.value( file('index.fa') )
1145+
sample.combine( single_end: true, index: index )
1146+
'''
1147+
)
1148+
type = getType(exp)
1149+
then:
1150+
TypesEx.getName(type) == 'Value<Record {\n id: Integer\n fastq: Path\n single_end: Boolean\n index: Path\n}>'
1151+
}
1152+
10861153
def 'should resolve a `groupBy` operation' () {
10871154
when:
10881155
def exp = parseExpression(

0 commit comments

Comments
 (0)