|
18 | 18 | import java.util.HashMap; |
19 | 19 | import java.util.List; |
20 | 20 | import java.util.Map; |
| 21 | +import java.util.Objects; |
21 | 22 | import java.util.Optional; |
| 23 | +import java.util.StringJoiner; |
| 24 | +import java.util.stream.Collectors; |
22 | 25 | import org.apache.calcite.rel.type.RelDataType; |
23 | 26 | import org.apache.calcite.rex.RexBuilder; |
24 | 27 | import org.apache.calcite.rex.RexNode; |
|
34 | 37 | import org.apache.calcite.sql.type.SqlTypeName; |
35 | 38 | import org.apache.calcite.sql.validate.SqlUserDefinedFunction; |
36 | 39 | import org.checkerframework.checker.nullness.qual.Nullable; |
| 40 | +import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory; |
| 41 | +import org.opensearch.sql.exception.ExpressionEvaluationException; |
37 | 42 | import org.opensearch.sql.executor.QueryType; |
38 | 43 |
|
39 | 44 | public class PPLFuncImpTable { |
@@ -143,8 +148,20 @@ public RexNode resolve( |
143 | 148 | functionName, argTypes, e.getMessage()), |
144 | 149 | e); |
145 | 150 | } |
146 | | - throw new IllegalArgumentException( |
147 | | - String.format("Cannot resolve function: %s, arguments: %s", functionName, argTypes)); |
| 151 | + StringJoiner joiner = new StringJoiner(","); |
| 152 | + for (var implement : implementList) { |
| 153 | + joiner.add(implement.getKey().typeChecker().getAllowedSignatures()); |
| 154 | + } |
| 155 | + String actualSignature = |
| 156 | + "[" |
| 157 | + + argTypes.stream() |
| 158 | + .map(OpenSearchTypeFactory::convertRelDataTypeToExprType) |
| 159 | + .map(Objects::toString) |
| 160 | + .collect(Collectors.joining(",")) |
| 161 | + + "]"; |
| 162 | + throw new ExpressionEvaluationException( |
| 163 | + String.format( |
| 164 | + "%s function expects {%s}, but got %s", functionName, joiner, actualSignature)); |
148 | 165 | } |
149 | 166 |
|
150 | 167 | @SuppressWarnings({"UnusedReturnValue", "SameParameterValue"}) |
@@ -364,57 +381,65 @@ void populate() { |
364 | 381 | // Note, make the implementation an individual class if too complex. |
365 | 382 | register( |
366 | 383 | TRIM, |
367 | | - ((FunctionImp1) |
| 384 | + createFunctionImpWithTypeChecker( |
368 | 385 | (builder, arg) -> |
369 | 386 | builder.makeCall( |
370 | 387 | SqlStdOperatorTable.TRIM, |
371 | 388 | builder.makeFlag(Flag.BOTH), |
372 | 389 | builder.makeLiteral(" "), |
373 | | - arg))); |
| 390 | + arg), |
| 391 | + family(SqlTypeFamily.STRING))); |
| 392 | + |
374 | 393 | register( |
375 | 394 | LTRIM, |
376 | | - ((FunctionImp1) |
| 395 | + createFunctionImpWithTypeChecker( |
377 | 396 | (builder, arg) -> |
378 | 397 | builder.makeCall( |
379 | 398 | SqlStdOperatorTable.TRIM, |
380 | 399 | builder.makeFlag(Flag.LEADING), |
381 | 400 | builder.makeLiteral(" "), |
382 | | - arg))); |
| 401 | + arg), |
| 402 | + family(SqlTypeFamily.STRING))); |
383 | 403 | register( |
384 | 404 | RTRIM, |
385 | | - ((FunctionImp1) |
| 405 | + createFunctionImpWithTypeChecker( |
386 | 406 | (builder, arg) -> |
387 | 407 | builder.makeCall( |
388 | 408 | SqlStdOperatorTable.TRIM, |
389 | 409 | builder.makeFlag(Flag.TRAILING), |
390 | 410 | builder.makeLiteral(" "), |
391 | | - arg))); |
| 411 | + arg), |
| 412 | + family(SqlTypeFamily.STRING))); |
392 | 413 | register( |
393 | 414 | STRCMP, |
394 | | - ((FunctionImp2) |
395 | | - (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.STRCMP, arg2, arg1))); |
| 415 | + createFunctionImpWithTypeChecker( |
| 416 | + (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.STRCMP, arg2, arg1), |
| 417 | + family(SqlTypeFamily.STRING, SqlTypeFamily.STRING))); |
396 | 418 | register( |
397 | 419 | LOG, |
398 | | - ((FunctionImp2) |
399 | | - (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.LOG, arg2, arg1))); |
| 420 | + createFunctionImpWithTypeChecker( |
| 421 | + (builder, arg1, arg2) -> builder.makeCall(SqlLibraryOperators.LOG, arg2, arg1), |
| 422 | + family(SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC))); |
400 | 423 | register( |
401 | 424 | LOG, |
402 | | - ((FunctionImp1) |
| 425 | + createFunctionImpWithTypeChecker( |
403 | 426 | (builder, arg) -> |
404 | 427 | builder.makeCall( |
405 | 428 | SqlLibraryOperators.LOG, |
406 | 429 | arg, |
407 | | - builder.makeApproxLiteral(BigDecimal.valueOf(Math.E))))); |
| 430 | + builder.makeApproxLiteral(BigDecimal.valueOf(Math.E))), |
| 431 | + family(SqlTypeFamily.NUMERIC))); |
408 | 432 | // SqlStdOperatorTable.SQRT is declared but not implemented. The call to SQRT in Calcite is |
409 | 433 | // converted to POWER(x, 0.5). |
410 | 434 | register( |
411 | 435 | SQRT, |
412 | | - ((FunctionImp1) |
| 436 | + createFunctionImpWithTypeChecker( |
413 | 437 | (builder, arg) -> |
414 | 438 | builder.makeCall( |
415 | 439 | SqlStdOperatorTable.POWER, |
416 | 440 | arg, |
417 | | - builder.makeApproxLiteral(BigDecimal.valueOf(0.5))))); |
| 441 | + builder.makeApproxLiteral(BigDecimal.valueOf(0.5))), |
| 442 | + family(SqlTypeFamily.NUMERIC))); |
418 | 443 | register( |
419 | 444 | TYPEOF, |
420 | 445 | (FunctionImp1) |
@@ -465,4 +490,44 @@ public PPLTypeChecker getTypeChecker() { |
465 | 490 | return family(booleanFamily, booleanFamily); |
466 | 491 | } |
467 | 492 | } |
| 493 | + |
| 494 | + @FunctionalInterface |
| 495 | + private interface RexUnaryResolver { |
| 496 | + RexNode apply(RexBuilder builder, RexNode node); |
| 497 | + } |
| 498 | + |
| 499 | + @FunctionalInterface |
| 500 | + private interface RexBinaryResolver { |
| 501 | + RexNode apply(RexBuilder builder, RexNode arg1, RexNode arg2); |
| 502 | + } |
| 503 | + |
| 504 | + private static FunctionImp createFunctionImpWithTypeChecker( |
| 505 | + RexUnaryResolver resolver, PPLTypeChecker typeChecker) { |
| 506 | + return new FunctionImp1() { |
| 507 | + @Override |
| 508 | + public RexNode resolve(RexBuilder builder, RexNode arg1) { |
| 509 | + return resolver.apply(builder, arg1); |
| 510 | + } |
| 511 | + |
| 512 | + @Override |
| 513 | + public PPLTypeChecker getTypeChecker() { |
| 514 | + return typeChecker; |
| 515 | + } |
| 516 | + }; |
| 517 | + } |
| 518 | + |
| 519 | + private static FunctionImp createFunctionImpWithTypeChecker( |
| 520 | + RexBinaryResolver resolver, PPLTypeChecker typeChecker) { |
| 521 | + return new FunctionImp2() { |
| 522 | + @Override |
| 523 | + public RexNode resolve(RexBuilder builder, RexNode arg1, RexNode arg2) { |
| 524 | + return resolver.apply(builder, arg1, arg2); |
| 525 | + } |
| 526 | + |
| 527 | + @Override |
| 528 | + public PPLTypeChecker getTypeChecker() { |
| 529 | + return typeChecker; |
| 530 | + } |
| 531 | + }; |
| 532 | + } |
468 | 533 | } |
0 commit comments