diff --git a/src/main/java/org/rumbledb/api/Item.java b/src/main/java/org/rumbledb/api/Item.java index 01d5f3f270..c1e790e7be 100644 --- a/src/main/java/org/rumbledb/api/Item.java +++ b/src/main/java/org/rumbledb/api/Item.java @@ -439,6 +439,15 @@ default public RuntimeIterator getBodyIterator() { throw new UnsupportedOperationException("Operation not defined for type " + this.getDynamicType()); } + /** + * Returns the body iterator, if it is a function item. + * + * @return the function signature. + */ + default public Map getBodyIterators() { + throw new UnsupportedOperationException("Operation not defined for type " + this.getDynamicType()); + } + /** * Returns the local variable bindings, if it is a function item. * diff --git a/src/main/java/org/rumbledb/compiler/RuntimeIteratorVisitor.java b/src/main/java/org/rumbledb/compiler/RuntimeIteratorVisitor.java index 35417fd49d..876d9f6c2f 100644 --- a/src/main/java/org/rumbledb/compiler/RuntimeIteratorVisitor.java +++ b/src/main/java/org/rumbledb/compiler/RuntimeIteratorVisitor.java @@ -142,6 +142,7 @@ import org.rumbledb.types.SequenceType; import java.util.ArrayList; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -497,12 +498,15 @@ public RuntimeIterator visitInlineFunctionExpr(InlineFunctionExpression expressi paramNameToSequenceTypes.put(paramEntry.getKey(), paramEntry.getValue()); } SequenceType returnType = expression.getReturnType(); - RuntimeIterator bodyIterator = this.visit(expression.getBody(), argument); + Map bodyIterators = new HashMap<>(); + for (long l : expression.getBodies().keySet()) { + bodyIterators.put(l, this.visit(expression.getBodies().get(l), argument)); + } RuntimeIterator runtimeIterator = new FunctionRuntimeIterator( expression.getName(), paramNameToSequenceTypes, returnType, - bodyIterator, + bodyIterators, expression.getHighestExecutionMode(this.visitorConfig), expression.getMetadata() ); diff --git a/src/main/java/org/rumbledb/compiler/StaticContextVisitor.java b/src/main/java/org/rumbledb/compiler/StaticContextVisitor.java index 3d95657546..bb87b2ea12 100644 --- a/src/main/java/org/rumbledb/compiler/StaticContextVisitor.java +++ b/src/main/java/org/rumbledb/compiler/StaticContextVisitor.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import org.rumbledb.context.BuiltinFunctionCatalogue; import org.rumbledb.context.FunctionIdentifier; @@ -191,23 +192,45 @@ public StaticContext visitFunctionDeclaration(FunctionDeclaration declaration, S @Override public StaticContext visitInlineFunctionExpr(InlineFunctionExpression expression, StaticContext argument) { // define a static context for the function body, add params to the context and visit the body expression - StaticContext functionDeclarationContext = new StaticContext(argument); - expression.getParams() - .forEach( - (paramName, sequenceType) -> functionDeclarationContext.addVariable( - paramName, - sequenceType, - expression.getMetadata(), - ExecutionMode.LOCAL - ) + StaticContext functionDeclarationContextLocal = new StaticContext(argument); + for (Entry entry : expression.getParams().entrySet()) { + functionDeclarationContextLocal.addVariable( + entry.getKey(), + entry.getValue(), + expression.getMetadata(), + ExecutionMode.LOCAL ); + } // visit the body first to make its execution mode available while adding the function to the catalog - this.visit(expression.getBody(), functionDeclarationContext); + this.visit(expression.getBodies().get(0L), functionDeclarationContextLocal); + + StaticContext functionDeclarationContextRDD = new StaticContext(argument); + boolean first = true; + for (Entry entry : expression.getParams().entrySet()) { + if (first) { + functionDeclarationContextRDD.addVariable( + entry.getKey(), + entry.getValue(), + expression.getMetadata(), + ExecutionMode.DATAFRAME + ); + first = false; + } else { + functionDeclarationContextRDD.addVariable( + entry.getKey(), + entry.getValue(), + expression.getMetadata(), + ExecutionMode.LOCAL + ); + } + } + StaticContext functionDeclarationContextDF = new StaticContext(argument); + this.visit(expression.getBodies().get(1L), functionDeclarationContextDF); expression.initHighestExecutionMode(this.visitorConfig); expression.registerUserDefinedFunctionExecutionMode( this.visitorConfig ); - return functionDeclarationContext; + return argument; } @Override diff --git a/src/main/java/org/rumbledb/compiler/TranslationVisitor.java b/src/main/java/org/rumbledb/compiler/TranslationVisitor.java index f39e9e2ed7..b7037fbcd6 100644 --- a/src/main/java/org/rumbledb/compiler/TranslationVisitor.java +++ b/src/main/java/org/rumbledb/compiler/TranslationVisitor.java @@ -396,7 +396,7 @@ public Name parseName(JsoniqParser.QnameContext ctx, boolean isFunction, boolean @Override public Node visitFunctionDecl(JsoniqParser.FunctionDeclContext ctx) { Name name = parseName(ctx.qname(), true, false); - Map fnParams = new LinkedHashMap<>(); + LinkedHashMap fnParams = new LinkedHashMap<>(); SequenceType fnReturnType = null; Name paramName; SequenceType paramType; @@ -1371,7 +1371,7 @@ public Node visitNamedFunctionRef(JsoniqParser.NamedFunctionRefContext ctx) { @Override public Node visitInlineFunctionExpr(JsoniqParser.InlineFunctionExprContext ctx) { - Map fnParams = new LinkedHashMap<>(); + LinkedHashMap fnParams = new LinkedHashMap<>(); SequenceType fnReturnType = SequenceType.MOST_GENERAL_SEQUENCE_TYPE; Name paramName; SequenceType paramType; diff --git a/src/main/java/org/rumbledb/compiler/VariableDependenciesVisitor.java b/src/main/java/org/rumbledb/compiler/VariableDependenciesVisitor.java index 5b3cf524b5..e0dd59a8d7 100644 --- a/src/main/java/org/rumbledb/compiler/VariableDependenciesVisitor.java +++ b/src/main/java/org/rumbledb/compiler/VariableDependenciesVisitor.java @@ -352,7 +352,9 @@ public Void visitContextExpr(ContextItemExpression expression, Void argument) { @Override public Void visitInlineFunctionExpr(InlineFunctionExpression expression, Void argument) { - visit(expression.getBody(), null); + for (long l : expression.getBodies().keySet()) { + visit(expression.getBodies().get(l), null); + } addInputVariableDependencies(expression, getInputVariableDependencies(expression.getBody())); removeInputVariableDependencies(expression, expression.getParams().keySet()); return null; diff --git a/src/main/java/org/rumbledb/compiler/XQueryTranslationVisitor.java b/src/main/java/org/rumbledb/compiler/XQueryTranslationVisitor.java index 37b20ce65d..c76066b4c4 100644 --- a/src/main/java/org/rumbledb/compiler/XQueryTranslationVisitor.java +++ b/src/main/java/org/rumbledb/compiler/XQueryTranslationVisitor.java @@ -371,7 +371,7 @@ private void processAnnotations(XQueryParser.AnnotationsContext annotations) { @Override public Node visitFunctionDecl(XQueryParser.FunctionDeclContext ctx) { Name name = parseName(ctx.eqName(), true, false); - Map fnParams = new LinkedHashMap<>(); + LinkedHashMap fnParams = new LinkedHashMap<>(); SequenceType fnReturnType = MOST_GENERAL_SEQUENCE_TYPE; Name paramName; SequenceType paramType; @@ -1279,7 +1279,7 @@ public Node visitNamedFunctionRef(XQueryParser.NamedFunctionRefContext ctx) { @Override public Node visitInlineFunctionRef(XQueryParser.InlineFunctionRefContext ctx) { - Map fnParams = new LinkedHashMap<>(); + LinkedHashMap fnParams = new LinkedHashMap<>(); SequenceType fnReturnType = SequenceType.MOST_GENERAL_SEQUENCE_TYPE; Name paramName; SequenceType paramType; diff --git a/src/main/java/org/rumbledb/expressions/CommaExpression.java b/src/main/java/org/rumbledb/expressions/CommaExpression.java index 8bebe5e981..8143244050 100644 --- a/src/main/java/org/rumbledb/expressions/CommaExpression.java +++ b/src/main/java/org/rumbledb/expressions/CommaExpression.java @@ -29,6 +29,8 @@ public class CommaExpression extends Expression { + private static final long serialVersionUID = 1L; + private final List expressions; public CommaExpression(List expressions, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/Expression.java b/src/main/java/org/rumbledb/expressions/Expression.java index 9cc6c5529d..c4d5d2700e 100644 --- a/src/main/java/org/rumbledb/expressions/Expression.java +++ b/src/main/java/org/rumbledb/expressions/Expression.java @@ -36,7 +36,8 @@ */ public abstract class Expression extends Node { - protected StaticContext staticContext; + private static final long serialVersionUID = 1L; + protected transient StaticContext staticContext; protected SequenceType inferredSequenceType; diff --git a/src/main/java/org/rumbledb/expressions/Node.java b/src/main/java/org/rumbledb/expressions/Node.java index 4105b8b561..6a1d5f839f 100644 --- a/src/main/java/org/rumbledb/expressions/Node.java +++ b/src/main/java/org/rumbledb/expressions/Node.java @@ -24,6 +24,7 @@ import org.rumbledb.exceptions.ExceptionMetadata; import org.rumbledb.exceptions.OurBadException; +import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.function.Predicate; @@ -32,8 +33,9 @@ * This is the top-level class for nodes in the intermediate representation of a * JSONiq query. Nodes include expressions, clauses, function declarations, etc. */ -public abstract class Node { +public abstract class Node implements Serializable { + private static final long serialVersionUID = 1L; private ExceptionMetadata metadata; protected ExecutionMode highestExecutionMode = ExecutionMode.UNSET; diff --git a/src/main/java/org/rumbledb/expressions/arithmetic/AdditiveExpression.java b/src/main/java/org/rumbledb/expressions/arithmetic/AdditiveExpression.java index c99b7c2f85..babb758e19 100644 --- a/src/main/java/org/rumbledb/expressions/arithmetic/AdditiveExpression.java +++ b/src/main/java/org/rumbledb/expressions/arithmetic/AdditiveExpression.java @@ -30,6 +30,8 @@ import java.util.List; public class AdditiveExpression extends Expression { + private static final long serialVersionUID = 1L; + private Expression leftExpression; private Expression rightExpression; private boolean isMinus; diff --git a/src/main/java/org/rumbledb/expressions/arithmetic/MultiplicativeExpression.java b/src/main/java/org/rumbledb/expressions/arithmetic/MultiplicativeExpression.java index 16423c5a59..50ea315305 100644 --- a/src/main/java/org/rumbledb/expressions/arithmetic/MultiplicativeExpression.java +++ b/src/main/java/org/rumbledb/expressions/arithmetic/MultiplicativeExpression.java @@ -31,6 +31,7 @@ import java.util.List; public class MultiplicativeExpression extends Expression { + private static final long serialVersionUID = 1L; public static enum MultiplicativeOperator { MUL("*"), diff --git a/src/main/java/org/rumbledb/expressions/comparison/ComparisonExpression.java b/src/main/java/org/rumbledb/expressions/comparison/ComparisonExpression.java index 4525c5d4a0..b507ba8c2e 100644 --- a/src/main/java/org/rumbledb/expressions/comparison/ComparisonExpression.java +++ b/src/main/java/org/rumbledb/expressions/comparison/ComparisonExpression.java @@ -31,6 +31,8 @@ public class ComparisonExpression extends Expression { + private static final long serialVersionUID = 1L; + public static enum ComparisonOperator { VC_EQ("eq"), VC_NE("ne"), diff --git a/src/main/java/org/rumbledb/expressions/control/ConditionalExpression.java b/src/main/java/org/rumbledb/expressions/control/ConditionalExpression.java index 81247b76ac..58fab4c5ac 100644 --- a/src/main/java/org/rumbledb/expressions/control/ConditionalExpression.java +++ b/src/main/java/org/rumbledb/expressions/control/ConditionalExpression.java @@ -33,6 +33,7 @@ public class ConditionalExpression extends Expression { + private static final long serialVersionUID = 1L; private final Expression conditionExpression; private final Expression thenExpression; private final Expression elseExpression; diff --git a/src/main/java/org/rumbledb/expressions/flowr/Clause.java b/src/main/java/org/rumbledb/expressions/flowr/Clause.java index 6a35d1d897..12014636eb 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/Clause.java +++ b/src/main/java/org/rumbledb/expressions/flowr/Clause.java @@ -34,6 +34,8 @@ public abstract class Clause extends Node { /* Clauses are organized in doubly-linked lists */ + private static final long serialVersionUID = 1L; + protected Clause previousClause; protected Clause nextClause; protected FLWOR_CLAUSES clauseType; diff --git a/src/main/java/org/rumbledb/expressions/flowr/CountClause.java b/src/main/java/org/rumbledb/expressions/flowr/CountClause.java index 3392464788..3edd67d359 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/CountClause.java +++ b/src/main/java/org/rumbledb/expressions/flowr/CountClause.java @@ -30,6 +30,7 @@ public class CountClause extends Clause { + private static final long serialVersionUID = 1L; private VariableReferenceExpression countClauseVar; public CountClause(VariableReferenceExpression countClauseVar, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/flowr/FlworExpression.java b/src/main/java/org/rumbledb/expressions/flowr/FlworExpression.java index f728b0a023..73c930e8d6 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/FlworExpression.java +++ b/src/main/java/org/rumbledb/expressions/flowr/FlworExpression.java @@ -32,6 +32,7 @@ public class FlworExpression extends Expression { + private static final long serialVersionUID = 1L; private ReturnClause returnClause; public FlworExpression( diff --git a/src/main/java/org/rumbledb/expressions/flowr/ForClause.java b/src/main/java/org/rumbledb/expressions/flowr/ForClause.java index 31a3373a17..bc686114af 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/ForClause.java +++ b/src/main/java/org/rumbledb/expressions/flowr/ForClause.java @@ -36,6 +36,7 @@ public class ForClause extends Clause { + private static final long serialVersionUID = 1L; private final Name variableName; private final boolean allowingEmpty; private final Name positionalVariableName; diff --git a/src/main/java/org/rumbledb/expressions/flowr/GroupByClause.java b/src/main/java/org/rumbledb/expressions/flowr/GroupByClause.java index d4fb0ccd63..3eb1cb3ee7 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/GroupByClause.java +++ b/src/main/java/org/rumbledb/expressions/flowr/GroupByClause.java @@ -32,6 +32,7 @@ public class GroupByClause extends Clause { + private static final long serialVersionUID = 1L; private final List variables; public GroupByClause(List variables, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/flowr/LetClause.java b/src/main/java/org/rumbledb/expressions/flowr/LetClause.java index 8be7d8545d..ec8eebff98 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/LetClause.java +++ b/src/main/java/org/rumbledb/expressions/flowr/LetClause.java @@ -36,6 +36,7 @@ public class LetClause extends Clause { + private static final long serialVersionUID = 1L; private final Name variableName; protected SequenceType sequenceType; protected Expression expression; diff --git a/src/main/java/org/rumbledb/expressions/flowr/OrderByClause.java b/src/main/java/org/rumbledb/expressions/flowr/OrderByClause.java index 5d892a177f..b5bb218af1 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/OrderByClause.java +++ b/src/main/java/org/rumbledb/expressions/flowr/OrderByClause.java @@ -31,6 +31,7 @@ public class OrderByClause extends Clause { + private static final long serialVersionUID = 1L; private final List sortingKeys; private final boolean isStable; diff --git a/src/main/java/org/rumbledb/expressions/flowr/ReturnClause.java b/src/main/java/org/rumbledb/expressions/flowr/ReturnClause.java index 5753a05840..f03e2420b9 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/ReturnClause.java +++ b/src/main/java/org/rumbledb/expressions/flowr/ReturnClause.java @@ -32,6 +32,7 @@ public class ReturnClause extends Clause { + private static final long serialVersionUID = 1L; private final Expression returnExpr; diff --git a/src/main/java/org/rumbledb/expressions/flowr/SimpleMapExpression.java b/src/main/java/org/rumbledb/expressions/flowr/SimpleMapExpression.java index 91ef9bc425..9d5a6ffc2f 100644 --- a/src/main/java/org/rumbledb/expressions/flowr/SimpleMapExpression.java +++ b/src/main/java/org/rumbledb/expressions/flowr/SimpleMapExpression.java @@ -31,6 +31,7 @@ import java.util.List; public class SimpleMapExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression leftExpression; private Expression rightExpression; diff --git a/src/main/java/org/rumbledb/expressions/logic/AndExpression.java b/src/main/java/org/rumbledb/expressions/logic/AndExpression.java index 82e8759ea9..74999d6960 100644 --- a/src/main/java/org/rumbledb/expressions/logic/AndExpression.java +++ b/src/main/java/org/rumbledb/expressions/logic/AndExpression.java @@ -29,6 +29,8 @@ import java.util.List; public class AndExpression extends Expression { + private static final long serialVersionUID = 1L; + private Expression leftExpression; private Expression rightExpression; diff --git a/src/main/java/org/rumbledb/expressions/logic/NotExpression.java b/src/main/java/org/rumbledb/expressions/logic/NotExpression.java index 89a1d3a849..bea2dfe51e 100644 --- a/src/main/java/org/rumbledb/expressions/logic/NotExpression.java +++ b/src/main/java/org/rumbledb/expressions/logic/NotExpression.java @@ -32,6 +32,7 @@ public class NotExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression mainExpression; public NotExpression(Expression mainExpression, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/logic/OrExpression.java b/src/main/java/org/rumbledb/expressions/logic/OrExpression.java index 253815c7e5..28bc1340f6 100644 --- a/src/main/java/org/rumbledb/expressions/logic/OrExpression.java +++ b/src/main/java/org/rumbledb/expressions/logic/OrExpression.java @@ -29,6 +29,7 @@ import java.util.List; public class OrExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression leftExpression; private Expression rightExpression; diff --git a/src/main/java/org/rumbledb/expressions/miscellaneous/RangeExpression.java b/src/main/java/org/rumbledb/expressions/miscellaneous/RangeExpression.java index c8b0a71ee4..d22ac08070 100644 --- a/src/main/java/org/rumbledb/expressions/miscellaneous/RangeExpression.java +++ b/src/main/java/org/rumbledb/expressions/miscellaneous/RangeExpression.java @@ -31,6 +31,7 @@ public class RangeExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression leftExpression; private Expression rightExpression; diff --git a/src/main/java/org/rumbledb/expressions/miscellaneous/StringConcatExpression.java b/src/main/java/org/rumbledb/expressions/miscellaneous/StringConcatExpression.java index e9b19573c0..ae23b440e6 100644 --- a/src/main/java/org/rumbledb/expressions/miscellaneous/StringConcatExpression.java +++ b/src/main/java/org/rumbledb/expressions/miscellaneous/StringConcatExpression.java @@ -30,6 +30,7 @@ import java.util.List; public class StringConcatExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression leftExpression; private Expression rightExpression; diff --git a/src/main/java/org/rumbledb/expressions/module/FunctionDeclaration.java b/src/main/java/org/rumbledb/expressions/module/FunctionDeclaration.java index 0eace7ff8d..851ea75f9f 100644 --- a/src/main/java/org/rumbledb/expressions/module/FunctionDeclaration.java +++ b/src/main/java/org/rumbledb/expressions/module/FunctionDeclaration.java @@ -34,6 +34,7 @@ public class FunctionDeclaration extends Node { + private static final long serialVersionUID = 1L; private final InlineFunctionExpression functionExpression; public FunctionDeclaration( diff --git a/src/main/java/org/rumbledb/expressions/module/LibraryModule.java b/src/main/java/org/rumbledb/expressions/module/LibraryModule.java index 08c4ca62e2..3c4c12bf89 100644 --- a/src/main/java/org/rumbledb/expressions/module/LibraryModule.java +++ b/src/main/java/org/rumbledb/expressions/module/LibraryModule.java @@ -31,6 +31,7 @@ public class LibraryModule extends Module { + private static final long serialVersionUID = 1L; protected StaticContext staticContext; private String namespace; private final Prolog prolog; diff --git a/src/main/java/org/rumbledb/expressions/module/MainModule.java b/src/main/java/org/rumbledb/expressions/module/MainModule.java index 5e1904ad39..892a8aed4f 100644 --- a/src/main/java/org/rumbledb/expressions/module/MainModule.java +++ b/src/main/java/org/rumbledb/expressions/module/MainModule.java @@ -33,6 +33,7 @@ public class MainModule extends Module { + private static final long serialVersionUID = 1L; protected StaticContext staticContext; private final Prolog prolog; private final Expression expression; diff --git a/src/main/java/org/rumbledb/expressions/module/Module.java b/src/main/java/org/rumbledb/expressions/module/Module.java index c35d1ae1a4..3213f21ebb 100644 --- a/src/main/java/org/rumbledb/expressions/module/Module.java +++ b/src/main/java/org/rumbledb/expressions/module/Module.java @@ -5,6 +5,8 @@ import org.rumbledb.expressions.Node; public abstract class Module extends Node { + private static final long serialVersionUID = 1L; + public Module(ExceptionMetadata metadata) { super(metadata); } diff --git a/src/main/java/org/rumbledb/expressions/module/Prolog.java b/src/main/java/org/rumbledb/expressions/module/Prolog.java index 1004b1367e..81331be346 100644 --- a/src/main/java/org/rumbledb/expressions/module/Prolog.java +++ b/src/main/java/org/rumbledb/expressions/module/Prolog.java @@ -31,6 +31,7 @@ public class Prolog extends Node { + private static final long serialVersionUID = 1L; private List declarations; private List importedModules; diff --git a/src/main/java/org/rumbledb/expressions/postfix/ArrayLookupExpression.java b/src/main/java/org/rumbledb/expressions/postfix/ArrayLookupExpression.java index 64257cbf2d..9cf5a2e1fd 100644 --- a/src/main/java/org/rumbledb/expressions/postfix/ArrayLookupExpression.java +++ b/src/main/java/org/rumbledb/expressions/postfix/ArrayLookupExpression.java @@ -33,6 +33,8 @@ public class ArrayLookupExpression extends Expression { + private static final long serialVersionUID = 1L; + private Expression mainExpression; private Expression lookupExpression; diff --git a/src/main/java/org/rumbledb/expressions/postfix/ArrayUnboxingExpression.java b/src/main/java/org/rumbledb/expressions/postfix/ArrayUnboxingExpression.java index 1997befd7e..f71b499c82 100644 --- a/src/main/java/org/rumbledb/expressions/postfix/ArrayUnboxingExpression.java +++ b/src/main/java/org/rumbledb/expressions/postfix/ArrayUnboxingExpression.java @@ -33,6 +33,8 @@ public class ArrayUnboxingExpression extends Expression { + private static final long serialVersionUID = 1L; + private Expression mainExpression; public ArrayUnboxingExpression(Expression mainExpression, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/postfix/DynamicFunctionCallExpression.java b/src/main/java/org/rumbledb/expressions/postfix/DynamicFunctionCallExpression.java index d2aa9cf11e..a36afa3132 100644 --- a/src/main/java/org/rumbledb/expressions/postfix/DynamicFunctionCallExpression.java +++ b/src/main/java/org/rumbledb/expressions/postfix/DynamicFunctionCallExpression.java @@ -20,9 +20,11 @@ package org.rumbledb.expressions.postfix; +import org.rumbledb.compiler.VisitorConfig; import org.rumbledb.exceptions.ExceptionMetadata; import org.rumbledb.exceptions.OurBadException; import org.rumbledb.expressions.AbstractNodeVisitor; +import org.rumbledb.expressions.ExecutionMode; import org.rumbledb.expressions.Expression; import org.rumbledb.expressions.Node; import java.util.ArrayList; @@ -31,6 +33,7 @@ public class DynamicFunctionCallExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression mainExpression; private List arguments; @@ -115,4 +118,13 @@ public void serializeToJSONiq(StringBuffer sb, int indent) { } sb.append(")\n"); } + + @Override + public void initHighestExecutionMode(VisitorConfig visitorConfig) { + if (this.arguments.size() == 0) { + this.highestExecutionMode = ExecutionMode.LOCAL; + return; + } + this.highestExecutionMode = this.arguments.get(0).getHighestExecutionMode(visitorConfig); + } } diff --git a/src/main/java/org/rumbledb/expressions/postfix/FilterExpression.java b/src/main/java/org/rumbledb/expressions/postfix/FilterExpression.java index 0c700186d2..52b72abe36 100644 --- a/src/main/java/org/rumbledb/expressions/postfix/FilterExpression.java +++ b/src/main/java/org/rumbledb/expressions/postfix/FilterExpression.java @@ -37,6 +37,7 @@ public class FilterExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression mainExpression; private Expression predicateExpression; diff --git a/src/main/java/org/rumbledb/expressions/postfix/ObjectLookupExpression.java b/src/main/java/org/rumbledb/expressions/postfix/ObjectLookupExpression.java index 6e17c5eaec..02cc16461c 100644 --- a/src/main/java/org/rumbledb/expressions/postfix/ObjectLookupExpression.java +++ b/src/main/java/org/rumbledb/expressions/postfix/ObjectLookupExpression.java @@ -33,6 +33,7 @@ public class ObjectLookupExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression mainExpression; private Expression lookupExpression; diff --git a/src/main/java/org/rumbledb/expressions/primary/ArrayConstructorExpression.java b/src/main/java/org/rumbledb/expressions/primary/ArrayConstructorExpression.java index b6da96d490..51eb88cdf0 100644 --- a/src/main/java/org/rumbledb/expressions/primary/ArrayConstructorExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/ArrayConstructorExpression.java @@ -31,6 +31,8 @@ public class ArrayConstructorExpression extends Expression { + private static final long serialVersionUID = 1L; + private Expression expression; public ArrayConstructorExpression(Expression expression, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/primary/BooleanLiteralExpression.java b/src/main/java/org/rumbledb/expressions/primary/BooleanLiteralExpression.java index b276978f98..5e23d7e15f 100644 --- a/src/main/java/org/rumbledb/expressions/primary/BooleanLiteralExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/BooleanLiteralExpression.java @@ -31,6 +31,8 @@ public class BooleanLiteralExpression extends Expression { + private static final long serialVersionUID = 1L; + private boolean value; public BooleanLiteralExpression(boolean value, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/primary/ContextItemExpression.java b/src/main/java/org/rumbledb/expressions/primary/ContextItemExpression.java index 65e0119283..04a55e54c0 100644 --- a/src/main/java/org/rumbledb/expressions/primary/ContextItemExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/ContextItemExpression.java @@ -30,6 +30,8 @@ public class ContextItemExpression extends Expression { + private static final long serialVersionUID = 1L; + public ContextItemExpression(ExceptionMetadata metadataFromContext) { super(metadataFromContext); } diff --git a/src/main/java/org/rumbledb/expressions/primary/DecimalLiteralExpression.java b/src/main/java/org/rumbledb/expressions/primary/DecimalLiteralExpression.java index b1c7e35589..94ffdf141a 100644 --- a/src/main/java/org/rumbledb/expressions/primary/DecimalLiteralExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/DecimalLiteralExpression.java @@ -32,6 +32,7 @@ public class DecimalLiteralExpression extends Expression { + private static final long serialVersionUID = 1L; private BigDecimal value; public DecimalLiteralExpression(BigDecimal value, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/primary/DoubleLiteralExpression.java b/src/main/java/org/rumbledb/expressions/primary/DoubleLiteralExpression.java index 3db15cc36a..4f0c24ec2e 100644 --- a/src/main/java/org/rumbledb/expressions/primary/DoubleLiteralExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/DoubleLiteralExpression.java @@ -31,6 +31,7 @@ public class DoubleLiteralExpression extends Expression { + private static final long serialVersionUID = 1L; private double value; public DoubleLiteralExpression(double value, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/primary/FunctionCallExpression.java b/src/main/java/org/rumbledb/expressions/primary/FunctionCallExpression.java index ceb14792fc..e7a15e3401 100644 --- a/src/main/java/org/rumbledb/expressions/primary/FunctionCallExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/FunctionCallExpression.java @@ -40,6 +40,7 @@ public class FunctionCallExpression extends Expression { + private static final long serialVersionUID = 1L; private final FunctionIdentifier identifier; private final List arguments; // null for placeholder private final boolean isPartialApplication; diff --git a/src/main/java/org/rumbledb/expressions/primary/InlineFunctionExpression.java b/src/main/java/org/rumbledb/expressions/primary/InlineFunctionExpression.java index e37f19c9dc..384ae1ac5d 100644 --- a/src/main/java/org/rumbledb/expressions/primary/InlineFunctionExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/InlineFunctionExpression.java @@ -21,6 +21,7 @@ package org.rumbledb.expressions.primary; +import org.apache.commons.lang3.SerializationUtils; import org.rumbledb.compiler.VisitorConfig; import org.rumbledb.context.FunctionIdentifier; import org.rumbledb.context.Name; @@ -30,21 +31,24 @@ import org.rumbledb.expressions.Node; import org.rumbledb.types.SequenceType; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; public class InlineFunctionExpression extends Expression { + private static final long serialVersionUID = 1L; private final Name name; private final FunctionIdentifier functionIdentifier; - private final Map params; + private final LinkedHashMap params; private final SequenceType returnType; - private final Expression body; + private final Map bodies; public InlineFunctionExpression( Name name, - Map params, + LinkedHashMap params, SequenceType returnType, Expression body, ExceptionMetadata metadata @@ -53,7 +57,12 @@ public InlineFunctionExpression( this.name = name; this.params = params; this.returnType = returnType; - this.body = body; + this.bodies = new HashMap<>(); + this.bodies.put(0L, body); + // for inline functions, we maintain another version for ML models + if (this.name == null) { + this.bodies.put(1L, (Expression) SerializationUtils.clone(body)); + } this.functionIdentifier = new FunctionIdentifier(name, params.size()); } @@ -78,12 +87,12 @@ public SequenceType getActualReturnType() { } public Expression getBody() { - return this.body; + return this.bodies.get(0L); } @Override public List getChildren() { - return Arrays.asList(this.body); + return new ArrayList<>(this.bodies.values()); } public void registerUserDefinedFunctionExecutionMode( @@ -95,7 +104,7 @@ public void registerUserDefinedFunctionExecutionMode( getStaticContext().getUserDefinedFunctionsExecutionModes() .setExecutionMode( identifier, - this.body.getHighestExecutionMode(visitorConfig), + this.bodies.get(0L).getHighestExecutionMode(visitorConfig), visitorConfig.suppressErrorsForFunctionSignatureCollision(), this.getMetadata() ); @@ -124,11 +133,13 @@ public void print(StringBuffer buffer, int indent) { buffer.append(" | " + this.highestExecutionMode); buffer.append(" | " + (this.inferredSequenceType == null ? "not set" : this.inferredSequenceType)); buffer.append("\n"); - for (int i = 0; i < indent + 2; ++i) { - buffer.append(" "); + for (long l : this.bodies.keySet()) { + for (int i = 0; i < indent + 1; ++i) { + buffer.append(" "); + } + buffer.append("Body " + l + ":\n"); + this.bodies.get(l).print(buffer, indent + 2); } - buffer.append("Body:\n"); - this.body.print(buffer, indent + 2); } @Override @@ -157,10 +168,14 @@ public void serializeToJSONiq(StringBuffer sb, int indent) { sb.append("\n"); indentIt(sb, indent); sb.append("{\n"); - this.body.serializeToJSONiq(sb, indent + 1); + this.bodies.get(0L).serializeToJSONiq(sb, indent + 1); indentIt(sb, indent); sb.append("}\n"); } } + + public Map getBodies() { + return this.bodies; + } } diff --git a/src/main/java/org/rumbledb/expressions/primary/IntegerLiteralExpression.java b/src/main/java/org/rumbledb/expressions/primary/IntegerLiteralExpression.java index fcbf9a639f..1c507d88d0 100644 --- a/src/main/java/org/rumbledb/expressions/primary/IntegerLiteralExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/IntegerLiteralExpression.java @@ -31,6 +31,7 @@ public class IntegerLiteralExpression extends Expression { + private static final long serialVersionUID = 1L; private String lexicalValue; public IntegerLiteralExpression(String lexicalValue, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/primary/NamedFunctionReferenceExpression.java b/src/main/java/org/rumbledb/expressions/primary/NamedFunctionReferenceExpression.java index b3342992c3..8dae35fb7a 100644 --- a/src/main/java/org/rumbledb/expressions/primary/NamedFunctionReferenceExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/NamedFunctionReferenceExpression.java @@ -32,6 +32,7 @@ public class NamedFunctionReferenceExpression extends Expression { + private static final long serialVersionUID = 1L; private final FunctionIdentifier identifier; public NamedFunctionReferenceExpression(FunctionIdentifier identifier, ExceptionMetadata metadata) { diff --git a/src/main/java/org/rumbledb/expressions/primary/NullLiteralExpression.java b/src/main/java/org/rumbledb/expressions/primary/NullLiteralExpression.java index 0b7a52808e..e508119600 100644 --- a/src/main/java/org/rumbledb/expressions/primary/NullLiteralExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/NullLiteralExpression.java @@ -32,6 +32,8 @@ public class NullLiteralExpression extends Expression { + private static final long serialVersionUID = 1L; + public NullLiteralExpression(ExceptionMetadata metadata) { super(metadata); } diff --git a/src/main/java/org/rumbledb/expressions/primary/ObjectConstructorExpression.java b/src/main/java/org/rumbledb/expressions/primary/ObjectConstructorExpression.java index 429baf5a63..c148267d47 100644 --- a/src/main/java/org/rumbledb/expressions/primary/ObjectConstructorExpression.java +++ b/src/main/java/org/rumbledb/expressions/primary/ObjectConstructorExpression.java @@ -31,6 +31,7 @@ public class ObjectConstructorExpression extends Expression { + private static final long serialVersionUID = 1L; private boolean isMergedConstructor = false; private List values; private List keys; diff --git a/src/main/java/org/rumbledb/expressions/typing/CastExpression.java b/src/main/java/org/rumbledb/expressions/typing/CastExpression.java index 6dd9e10bbe..5fea4a2e06 100644 --- a/src/main/java/org/rumbledb/expressions/typing/CastExpression.java +++ b/src/main/java/org/rumbledb/expressions/typing/CastExpression.java @@ -13,6 +13,8 @@ public class CastExpression extends Expression { + private static final long serialVersionUID = 1L; + private Expression mainExpression; private SequenceType sequenceType; diff --git a/src/main/java/org/rumbledb/expressions/typing/CastableExpression.java b/src/main/java/org/rumbledb/expressions/typing/CastableExpression.java index d790fa91cf..8b3fe0303c 100644 --- a/src/main/java/org/rumbledb/expressions/typing/CastableExpression.java +++ b/src/main/java/org/rumbledb/expressions/typing/CastableExpression.java @@ -13,6 +13,8 @@ public class CastableExpression extends Expression { + private static final long serialVersionUID = 1L; + protected Expression mainExpression; private SequenceType sequenceType; diff --git a/src/main/java/org/rumbledb/expressions/typing/InstanceOfExpression.java b/src/main/java/org/rumbledb/expressions/typing/InstanceOfExpression.java index ef730359a8..30545cc0a8 100644 --- a/src/main/java/org/rumbledb/expressions/typing/InstanceOfExpression.java +++ b/src/main/java/org/rumbledb/expressions/typing/InstanceOfExpression.java @@ -33,6 +33,7 @@ public class InstanceOfExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression mainExpression; private SequenceType sequenceType; diff --git a/src/main/java/org/rumbledb/expressions/typing/IsStaticallyExpression.java b/src/main/java/org/rumbledb/expressions/typing/IsStaticallyExpression.java index cad37ad1d1..449bcbdb03 100644 --- a/src/main/java/org/rumbledb/expressions/typing/IsStaticallyExpression.java +++ b/src/main/java/org/rumbledb/expressions/typing/IsStaticallyExpression.java @@ -11,6 +11,7 @@ import java.util.List; public class IsStaticallyExpression extends Expression { + private static final long serialVersionUID = 1L; private Expression mainExpression; private SequenceType sequenceType; diff --git a/src/main/java/org/rumbledb/items/FunctionItem.java b/src/main/java/org/rumbledb/items/FunctionItem.java index 5a219b04c6..2d6a9cd652 100644 --- a/src/main/java/org/rumbledb/items/FunctionItem.java +++ b/src/main/java/org/rumbledb/items/FunctionItem.java @@ -58,7 +58,7 @@ public class FunctionItem implements Item { // signature contains type information for all parameters and the return value private FunctionSignature signature; - private RuntimeIterator bodyIterator; + private Map bodyIterators; private DynamicContext dynamicModuleContext; private Map> localVariablesInClosure; private Map> RDDVariablesInClosure; @@ -73,12 +73,12 @@ public FunctionItem( List parameterNames, FunctionSignature signature, DynamicContext dynamicModuleContext, - RuntimeIterator bodyIterator + Map bodyIterators ) { this.identifier = identifier; this.parameterNames = parameterNames; this.signature = signature; - this.bodyIterator = bodyIterator; + this.bodyIterators = bodyIterators; this.dynamicModuleContext = dynamicModuleContext; this.localVariablesInClosure = new HashMap<>(); this.RDDVariablesInClosure = new HashMap<>(); @@ -90,7 +90,7 @@ public FunctionItem( List parameterNames, FunctionSignature signature, DynamicContext dynamicModuleContext, - RuntimeIterator bodyIterator, + Map bodyIterators, Map> localVariablesInClosure, Map> RDDVariablesInClosure, Map DFVariablesInClosure @@ -98,7 +98,7 @@ public FunctionItem( this.identifier = identifier; this.parameterNames = parameterNames; this.signature = signature; - this.bodyIterator = bodyIterator; + this.bodyIterators = bodyIterators; this.dynamicModuleContext = dynamicModuleContext; this.localVariablesInClosure = localVariablesInClosure; this.RDDVariablesInClosure = RDDVariablesInClosure; @@ -110,7 +110,7 @@ public FunctionItem( Map paramNameToSequenceTypes, SequenceType returnType, DynamicContext dynamicModuleContext, - RuntimeIterator bodyIterator + Map bodyIterators ) { List paramNames = new ArrayList<>(); List parameters = new ArrayList<>(); @@ -122,7 +122,7 @@ public FunctionItem( this.identifier = new FunctionIdentifier(name, paramNames.size()); this.parameterNames = paramNames; this.signature = new FunctionSignature(parameters, returnType); - this.bodyIterator = bodyIterator; + this.bodyIterators = bodyIterators; this.dynamicModuleContext = dynamicModuleContext; this.localVariablesInClosure = new HashMap<>(); this.RDDVariablesInClosure = new HashMap<>(); @@ -149,7 +149,11 @@ public DynamicContext getModuleDynamicContext() { } public RuntimeIterator getBodyIterator() { - return this.bodyIterator; + return this.bodyIterators.get(0L); + } + + public Map getBodyIterators() { + return this.bodyIterators; } public Map> getLocalVariablesInClosure() { @@ -200,7 +204,9 @@ public String toString() { sb.append(param + " "); } sb.append("Signature: " + this.signature + "\n"); - sb.append("Body:\n" + this.bodyIterator + "\n"); + for (long l : this.bodyIterators.keySet()) { + sb.append("Body " + l + ":\n" + this.bodyIterators.get(l) + "\n"); + } sb.append("Closure:\n"); sb.append(" Local:\n"); for (Name name : this.localVariablesInClosure.keySet()) { @@ -243,7 +249,7 @@ public void write(Kryo kryo, Output output) { try { ByteArrayOutputStream bos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(bos); - oos.writeObject(this.bodyIterator); + oos.writeObject(this.bodyIterators); oos.flush(); byte[] data = bos.toByteArray(); output.writeInt(data.length); @@ -274,7 +280,7 @@ public void read(Kryo kryo, Input input) { byte[] data = input.readBytes(dataLength); ByteArrayInputStream bis = new ByteArrayInputStream(data); ObjectInputStream ois = new ObjectInputStream(bis); - this.bodyIterator = (RuntimeIterator) ois.readObject(); + this.bodyIterators = (Map) ois.readObject(); } catch (Exception e) { throw new OurBadException( "Error converting functionItem-bodyRuntimeIterator to functionItem:" + e.getMessage() diff --git a/src/main/java/org/rumbledb/runtime/functions/DynamicFunctionCallIterator.java b/src/main/java/org/rumbledb/runtime/functions/DynamicFunctionCallIterator.java index ec07f74a5e..4b5043e945 100644 --- a/src/main/java/org/rumbledb/runtime/functions/DynamicFunctionCallIterator.java +++ b/src/main/java/org/rumbledb/runtime/functions/DynamicFunctionCallIterator.java @@ -20,20 +20,23 @@ package org.rumbledb.runtime.functions; +import org.apache.spark.api.java.JavaRDD; import org.rumbledb.api.Item; import org.rumbledb.context.DynamicContext; import org.rumbledb.context.NamedFunctions; import org.rumbledb.exceptions.ExceptionMetadata; import org.rumbledb.exceptions.IteratorFlowException; import org.rumbledb.exceptions.MoreThanOneItemException; +import org.rumbledb.exceptions.OurBadException; import org.rumbledb.exceptions.UnexpectedTypeException; import org.rumbledb.expressions.ExecutionMode; -import org.rumbledb.runtime.LocalRuntimeIterator; +import org.rumbledb.items.structured.JSoundDataFrame; +import org.rumbledb.runtime.HybridRuntimeIterator; import org.rumbledb.runtime.RuntimeIterator; import java.util.List; -public class DynamicFunctionCallIterator extends LocalRuntimeIterator { +public class DynamicFunctionCallIterator extends HybridRuntimeIterator { // dynamic: functionIdentifier is not known at compile time // it is known only after evaluating postfix expression at runtime @@ -67,15 +70,19 @@ public DynamicFunctionCallIterator( } @Override - public void open(DynamicContext context) { - super.open(context); - setFunctionItemAndIteratorWithCurrentContext(); + public void openLocal() { + setFunctionItemAndIteratorWithCurrentContext(this.currentDynamicContextForLocalExecution); this.functionCallIterator.open(this.currentDynamicContextForLocalExecution); setNextResult(); } @Override - public Item next() { + protected boolean hasNextLocal() { + return this.hasNext; + } + + @Override + public Item nextLocal() { if (this.hasNext) { Item result = this.nextResult; setNextResult(); @@ -103,11 +110,9 @@ public void setNextResult() { } } - private void setFunctionItemAndIteratorWithCurrentContext() { + private void setFunctionItemAndIteratorWithCurrentContext(DynamicContext context) { try { - this.functionItem = this.functionItemIterator.materializeAtMostOneItemOrNull( - this.currentDynamicContextForLocalExecution - ); + this.functionItem = this.functionItemIterator.materializeAtMostOneItemOrNull(context); } catch (MoreThanOneItemException e) { throw new UnexpectedTypeException( "A dynamic function call can not be performed on a sequence of more than one item.", @@ -120,6 +125,14 @@ private void setFunctionItemAndIteratorWithCurrentContext() { getMetadata() ); } + if (!this.functionItem.getBodyIterator().getHighestExecutionMode().equals(this.getHighestExecutionMode())) { + throw new OurBadException( + "Execution mode mismatch in dynamic function call: expression expects " + + this.getHighestExecutionMode() + + " but function item expects " + + this.functionItem.getBodyIterator().getHighestExecutionMode() + ); + } this.functionCallIterator = NamedFunctions.buildUserDefinedFunctionCallIterator( this.functionItem, this.functionItem.getBodyIterator().getHighestExecutionMode(), @@ -129,19 +142,29 @@ private void setFunctionItemAndIteratorWithCurrentContext() { } @Override - public void reset(DynamicContext context) { - super.reset(context); + public void resetLocal() { this.functionCallIterator.reset(this.currentDynamicContextForLocalExecution); setNextResult(); } @Override - public void close() { + public void closeLocal() { // ensure that recursive function calls terminate gracefully // the function call in the body of the deepest recursion call is never visited, never opened and never closed if (this.isOpen) { this.functionCallIterator.close(); } - super.close(); + } + + @Override + public JavaRDD getRDDAux(DynamicContext dynamicContext) { + setFunctionItemAndIteratorWithCurrentContext(dynamicContext); + return this.functionCallIterator.getRDD(dynamicContext); + } + + @Override + public JSoundDataFrame getDataFrame(DynamicContext dynamicContext) { + setFunctionItemAndIteratorWithCurrentContext(dynamicContext); + return this.functionCallIterator.getDataFrame(dynamicContext); } } diff --git a/src/main/java/org/rumbledb/runtime/functions/FunctionItemCallIterator.java b/src/main/java/org/rumbledb/runtime/functions/FunctionItemCallIterator.java index 88241894dd..92d348079d 100644 --- a/src/main/java/org/rumbledb/runtime/functions/FunctionItemCallIterator.java +++ b/src/main/java/org/rumbledb/runtime/functions/FunctionItemCallIterator.java @@ -232,7 +232,7 @@ private RuntimeIterator generatePartiallyAppliedFunction(DynamicContext context) this.functionItem.getSignature().getReturnType() ), this.functionItem.getModuleDynamicContext(), - this.functionItem.getBodyIterator(), + this.functionItem.getBodyIterators(), localArgumentValues, RDDArgumentValues, DFArgumentValues diff --git a/src/main/java/org/rumbledb/runtime/functions/FunctionRuntimeIterator.java b/src/main/java/org/rumbledb/runtime/functions/FunctionRuntimeIterator.java index f57ad0206a..6d5c60fcdc 100644 --- a/src/main/java/org/rumbledb/runtime/functions/FunctionRuntimeIterator.java +++ b/src/main/java/org/rumbledb/runtime/functions/FunctionRuntimeIterator.java @@ -20,6 +20,7 @@ package org.rumbledb.runtime.functions; +import java.util.HashMap; import java.util.Map; import org.rumbledb.api.Item; @@ -38,13 +39,13 @@ public class FunctionRuntimeIterator extends AtMostOneItemLocalRuntimeIterator { private Name functionName; private Map paramNameToSequenceTypes; SequenceType returnType; - RuntimeIterator bodyIterator; + Map bodyIterators; public FunctionRuntimeIterator( Name functionName, Map paramNameToSequenceTypes, SequenceType returnType, - RuntimeIterator bodyIterator, + Map bodyIterators, ExecutionMode executionMode, ExceptionMetadata iteratorMetadata ) { @@ -52,18 +53,21 @@ public FunctionRuntimeIterator( this.functionName = functionName; this.paramNameToSequenceTypes = paramNameToSequenceTypes; this.returnType = returnType; - this.bodyIterator = bodyIterator; + this.bodyIterators = bodyIterators; } @Override public Item materializeFirstItemOrNull(DynamicContext dynamicContext) { - RuntimeIterator bodyIteratorCopy = ((RuntimeIterator) this.bodyIterator).deepCopy(); + Map bodyIteratorsCopy = new HashMap<>(); + for (long l : this.bodyIterators.keySet()) { + bodyIteratorsCopy.put(l, (RuntimeIterator) this.bodyIterators.get(l).deepCopy()); + } FunctionItem function = new FunctionItem( this.functionName, this.paramNameToSequenceTypes, this.returnType, dynamicContext.getModuleContext(), - bodyIteratorCopy + bodyIteratorsCopy ); function.populateClosureFromDynamicContext(dynamicContext, getMetadata()); return function; diff --git a/src/main/java/sparksoniq/spark/ml/ApplyEstimatorRuntimeIterator.java b/src/main/java/sparksoniq/spark/ml/ApplyEstimatorRuntimeIterator.java index 7df04db024..91b37fc8cd 100644 --- a/src/main/java/sparksoniq/spark/ml/ApplyEstimatorRuntimeIterator.java +++ b/src/main/java/sparksoniq/spark/ml/ApplyEstimatorRuntimeIterator.java @@ -27,7 +27,9 @@ import java.lang.reflect.InvocationTargetException; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; import static sparksoniq.spark.ml.RumbleMLUtils.convertRumbleObjectItemToSparkMLParamMap; @@ -232,11 +234,24 @@ private void setSparkMLEstimatorParamToValue(String paramName, String value) { } private Item generateTransformerFunctionItem(Transformer fittedModel) { - RuntimeIterator bodyIterator = new ApplyTransformerRuntimeIterator( - RumbleMLCatalog.getRumbleMLShortName(fittedModel.getClass().getName()), - fittedModel, - ExecutionMode.DATAFRAME, - getMetadata() + Map bodyIterators = new HashMap<>(); + bodyIterators.put( + 0L, + new ApplyTransformerRuntimeIterator( + RumbleMLCatalog.getRumbleMLShortName(fittedModel.getClass().getName()), + fittedModel, + ExecutionMode.LOCAL, + getMetadata() + ) + ); + bodyIterators.put( + 1L, + new ApplyTransformerRuntimeIterator( + RumbleMLCatalog.getRumbleMLShortName(fittedModel.getClass().getName()), + fittedModel, + ExecutionMode.DATAFRAME, + getMetadata() + ) ); List paramTypes = Collections.unmodifiableList( Arrays.asList( @@ -266,7 +281,7 @@ private Item generateTransformerFunctionItem(Transformer fittedModel) { returnType ), new DynamicContext(this.currentDynamicContextForLocalExecution.getRumbleRuntimeConfiguration()), - bodyIterator + bodyIterators ); } } diff --git a/src/main/java/sparksoniq/spark/ml/GetEstimatorFunctionIterator.java b/src/main/java/sparksoniq/spark/ml/GetEstimatorFunctionIterator.java index 011ba47834..33c25dc2aa 100644 --- a/src/main/java/sparksoniq/spark/ml/GetEstimatorFunctionIterator.java +++ b/src/main/java/sparksoniq/spark/ml/GetEstimatorFunctionIterator.java @@ -40,7 +40,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class GetEstimatorFunctionIterator extends LocalFunctionCallIterator { @@ -107,11 +109,24 @@ public Item next() { this.hasNext = false; try { Estimator estimator = (Estimator) this.estimatorSparkMLClass.newInstance(); - RuntimeIterator bodyIterator = new ApplyEstimatorRuntimeIterator( - this.estimatorShortName, - estimator, - ExecutionMode.LOCAL, - getMetadata() + Map bodyIterators = new HashMap<>(); + bodyIterators.put( + 0L, + new ApplyEstimatorRuntimeIterator( + this.estimatorShortName, + estimator, + ExecutionMode.LOCAL, + getMetadata() + ) + ); + bodyIterators.put( + 1L, + new ApplyEstimatorRuntimeIterator( + this.estimatorShortName, + estimator, + ExecutionMode.DATAFRAME, + getMetadata() + ) ); List paramTypes = Collections.unmodifiableList( Arrays.asList( @@ -143,7 +158,7 @@ public Item next() { returnType ), new DynamicContext(this.currentDynamicContextForLocalExecution.getRumbleRuntimeConfiguration()), - bodyIterator + bodyIterators ); } catch (InstantiationException | IllegalAccessException e) { diff --git a/src/main/java/sparksoniq/spark/ml/GetTransformerFunctionIterator.java b/src/main/java/sparksoniq/spark/ml/GetTransformerFunctionIterator.java index c12ab7b965..7adb46f784 100644 --- a/src/main/java/sparksoniq/spark/ml/GetTransformerFunctionIterator.java +++ b/src/main/java/sparksoniq/spark/ml/GetTransformerFunctionIterator.java @@ -40,7 +40,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class GetTransformerFunctionIterator extends LocalFunctionCallIterator { @@ -107,11 +109,24 @@ public Item next() { this.hasNext = false; try { Transformer transformer = (Transformer) this.transformerSparkMLClass.newInstance(); - RuntimeIterator bodyIterator = new ApplyTransformerRuntimeIterator( - this.transformerShortName, - transformer, - ExecutionMode.DATAFRAME, - getMetadata() + Map bodyIterators = new HashMap<>(); + bodyIterators.put( + 0L, + new ApplyTransformerRuntimeIterator( + this.transformerShortName, + transformer, + ExecutionMode.LOCAL, + getMetadata() + ) + ); + bodyIterators.put( + 1L, + new ApplyTransformerRuntimeIterator( + this.transformerShortName, + transformer, + ExecutionMode.DATAFRAME, + getMetadata() + ) ); List paramTypes = Collections.unmodifiableList( Arrays.asList( @@ -143,7 +158,7 @@ public Item next() { returnType ), new DynamicContext(this.currentDynamicContextForLocalExecution.getRumbleRuntimeConfiguration()), - bodyIterator + bodyIterators ); } catch (InstantiationException | IllegalAccessException e) { diff --git a/src/test/java/iq/base/AnnotationsTestsBase.java b/src/test/java/iq/base/AnnotationsTestsBase.java index b0447c30c5..9cdc860e15 100644 --- a/src/test/java/iq/base/AnnotationsTestsBase.java +++ b/src/test/java/iq/base/AnnotationsTestsBase.java @@ -20,6 +20,7 @@ package iq.base; +import org.apache.commons.lang.exception.ExceptionUtils; import org.junit.Assert; import org.rumbledb.api.Item; import org.rumbledb.api.Rumble; @@ -175,7 +176,7 @@ protected void testAnnotations(String path, RumbleRuntimeConfiguration configura checkExpectedOutput(this.currentAnnotation.getOutput(), sequence); } catch (RumbleException exception) { String errorOutput = exception.getMessage(); - exception.printStackTrace(); + errorOutput += "\n" + ExceptionUtils.getStackTrace(exception); Assert.fail("Program did not run when expected to.\nError output: " + errorOutput + "\n"); } } else {