@@ -122,8 +122,7 @@ let verify_name_fresh_udf loc tenv name =
122122 (* variadic functions are currently not in math sigs and aren't
123123 overloadable due to their separate typechecking *)
124124 Stan_math_signatures. is_reduce_sum_fn name
125- || Stan_math_signatures. is_variadic_ode_fn name
126- || Stan_math_signatures. is_variadic_dae_fn name
125+ || Stan_math_signatures. is_stan_math_variadic_function_name name
127126 then Semantic_error. ident_is_stanmath_name loc name |> error
128127 else if Utils. is_unnormalized_distribution name then
129128 Semantic_error. udf_is_unnormalized_fn loc name |> error
@@ -191,14 +190,13 @@ let match_to_rt_option = function
191190 | _ -> None
192191
193192let stan_math_return_type name arg_tys =
194- match name with
195- | x when Stan_math_signatures. is_reduce_sum_fn x ->
193+ match
194+ Hashtbl. find Stan_math_signatures. stan_math_variadic_signatures name
195+ with
196+ | Some {return_type; _} -> Some (UnsizedType. ReturnType return_type)
197+ | None when Stan_math_signatures. is_reduce_sum_fn name ->
196198 Some (UnsizedType. ReturnType UReal )
197- | x when Stan_math_signatures. is_variadic_ode_fn x ->
198- Some (UnsizedType. ReturnType (UArray UVector ))
199- | x when Stan_math_signatures. is_variadic_dae_fn x ->
200- Some (UnsizedType. ReturnType (UArray UVector ))
201- | _ ->
199+ | None ->
202200 SignatureMismatch. matching_stanlib_function name arg_tys
203201 |> match_to_rt_option
204202
@@ -571,30 +569,25 @@ let make_function_variable cf loc id = function
571569
572570let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list )
573571 =
574- if Stan_math_signatures. is_reduce_sum_fn id.name then
572+ if Stan_math_signatures. is_stan_math_variadic_function_name id.name then
573+ check_variadic ~is_cond_dist loc cf tenv id tes
574+ else if Stan_math_signatures. is_reduce_sum_fn id.name then
575575 check_reduce_sum ~is_cond_dist loc cf tenv id tes
576- else if Stan_math_signatures. is_variadic_ode_fn id.name then
577- check_variadic_ode ~is_cond_dist loc cf tenv id tes
578- else if Stan_math_signatures. is_variadic_dae_fn id.name then
579- check_variadic_dae ~is_cond_dist loc cf tenv id tes
580576 else check_normal_fn ~is_cond_dist loc tenv id tes
581577
578+ (* * Reduce sum is a special case, even compared to the other
579+ variadic functions, because it is polymorphic in the type of the
580+ first argument. The first, fourth, and fifth arguments must agree,
581+ which is too complicated to be captured declaratively. *)
582582and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
583583 let basic_mismatch () =
584584 let mandatory_args =
585585 UnsizedType. [(AutoDiffable , UArray UReal ); (AutoDiffable , UInt )] in
586586 let mandatory_fun_args =
587587 UnsizedType.
588588 [(AutoDiffable , UArray UReal ); (DataOnly , UInt ); (DataOnly , UInt )] in
589- SignatureMismatch. check_variadic_args true mandatory_args mandatory_fun_args
590- UReal (get_arg_types tes) in
591- let fail () =
592- let expected_args, err =
593- basic_mismatch () |> Result. error |> Option. value_exn in
594- Semantic_error. illtyped_reduce_sum_generic loc id.name
595- (List. map ~f: type_of_expr_typed tes)
596- expected_args err
597- |> error in
589+ SignatureMismatch. check_variadic_args ~allow_lpdf: true mandatory_args
590+ mandatory_fun_args UReal (get_arg_types tes) in
598591 let matching remaining_es fn =
599592 match fn with
600593 | Env.
@@ -611,7 +604,7 @@ and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
611604 let arg_types =
612605 (calculate_autodifftype cf Functions ftype, ftype)
613606 :: get_arg_types remaining_es in
614- SignatureMismatch. check_variadic_args true mandatory_args
607+ SignatureMismatch. check_variadic_args ~allow_lpdf: true mandatory_args
615608 mandatory_fun_args UReal arg_types
616609 | _ -> basic_mismatch () in
617610 match tes with
@@ -633,81 +626,25 @@ and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
633626 (List. map ~f: type_of_expr_typed tes)
634627 expected_args err
635628 |> error )
636- | _ -> fail ()
637-
638- and check_variadic_ode ~is_cond_dist loc cf tenv id tes =
639- let optional_tol_mandatory_args =
640- if Stan_math_signatures. variadic_ode_adjoint_fn = id.name then
641- Stan_math_signatures. variadic_ode_adjoint_ctl_tol_arg_types
642- else if Stan_math_signatures. is_variadic_ode_nonadjoint_tol_fn id.name then
643- Stan_math_signatures. variadic_ode_tol_arg_types
644- else [] in
645- let mandatory_arg_types =
646- Stan_math_signatures. variadic_ode_mandatory_arg_types
647- @ optional_tol_mandatory_args in
648- let fail () =
649- let expected_args, err =
650- SignatureMismatch. check_variadic_args false mandatory_arg_types
651- Stan_math_signatures. variadic_ode_mandatory_fun_args
652- Stan_math_signatures. variadic_ode_fun_return_type (get_arg_types tes)
653- |> Result. error |> Option. value_exn in
654- Semantic_error. illtyped_variadic_ode loc id.name
655- (List. map ~f: type_of_expr_typed tes)
656- expected_args err
657- |> error in
658- let matching remaining_es Env. {type_ = ftype ; _} =
659- let arg_types =
660- (calculate_autodifftype cf Functions ftype, ftype)
661- :: get_arg_types remaining_es in
662- SignatureMismatch. check_variadic_args false mandatory_arg_types
663- Stan_math_signatures. variadic_ode_mandatory_fun_args
664- Stan_math_signatures. variadic_ode_fun_return_type arg_types in
665- match tes with
666- | {expr = Variable fname ; _} :: remaining_es -> (
667- match find_matching_first_order_fn tenv (matching remaining_es) fname with
668- | SignatureMismatch. UniqueMatch (ftype , promotions ) ->
669- let tes = make_function_variable cf loc fname ftype :: remaining_es in
670- mk_typed_expression
671- ~expr:
672- (mk_fun_app ~is_cond_dist
673- (StanLib FnPlain , id, Promotion. promote_list tes promotions) )
674- ~ad_level: (expr_ad_lub tes)
675- ~type_: Stan_math_signatures. variadic_ode_return_type ~loc
676- | AmbiguousMatch ps ->
677- Semantic_error. ambiguous_function_promotion loc fname.name None ps
678- |> error
679- | SignatureErrors (expected_args , err ) ->
680- Semantic_error. illtyped_variadic_ode loc id.name
681- (List. map ~f: type_of_expr_typed tes)
682- expected_args err
683- |> error )
684- | _ -> fail ()
685-
686- and check_variadic_dae ~is_cond_dist loc cf tenv id tes =
687- let optional_tol_mandatory_args =
688- if Stan_math_signatures. is_variadic_dae_tol_fn id.name then
689- Stan_math_signatures. variadic_dae_tol_arg_types
690- else [] in
691- let mandatory_arg_types =
692- Stan_math_signatures. variadic_dae_mandatory_arg_types
693- @ optional_tol_mandatory_args in
694- let fail () =
695- let expected_args, err =
696- SignatureMismatch. check_variadic_args false mandatory_arg_types
697- Stan_math_signatures. variadic_dae_mandatory_fun_args
698- Stan_math_signatures. variadic_dae_fun_return_type (get_arg_types tes)
699- |> Result. error |> Option. value_exn in
700- Semantic_error. illtyped_variadic_dae loc id.name
701- (List. map ~f: type_of_expr_typed tes)
702- expected_args err
703- |> error in
629+ | _ ->
630+ let expected_args, err =
631+ basic_mismatch () |> Result. error |> Option. value_exn in
632+ Semantic_error. illtyped_reduce_sum_generic loc id.name
633+ (List. map ~f: type_of_expr_typed tes)
634+ expected_args err
635+ |> error
636+
637+ and check_variadic ~is_cond_dist loc cf tenv id tes =
638+ let Stan_math_signatures.
639+ {control_args; required_fn_args; required_fn_rt; return_type} =
640+ Hashtbl. find_exn Stan_math_signatures. stan_math_variadic_signatures id.name
641+ in
704642 let matching remaining_es Env. {type_ = ftype ; _} =
705643 let arg_types =
706644 (calculate_autodifftype cf Functions ftype, ftype)
707645 :: get_arg_types remaining_es in
708- SignatureMismatch. check_variadic_args false mandatory_arg_types
709- Stan_math_signatures. variadic_dae_mandatory_fun_args
710- Stan_math_signatures. variadic_dae_fun_return_type arg_types in
646+ SignatureMismatch. check_variadic_args ~allow_lpdf: false control_args
647+ required_fn_args required_fn_rt arg_types in
711648 match tes with
712649 | {expr = Variable fname ; _} :: remaining_es -> (
713650 match find_matching_first_order_fn tenv (matching remaining_es) fname with
@@ -717,17 +654,24 @@ and check_variadic_dae ~is_cond_dist loc cf tenv id tes =
717654 ~expr:
718655 (mk_fun_app ~is_cond_dist
719656 (StanLib FnPlain , id, Promotion. promote_list tes promotions) )
720- ~ad_level: (expr_ad_lub tes)
721- ~type_: Stan_math_signatures. variadic_dae_return_type ~loc
657+ ~ad_level: (expr_ad_lub tes) ~type_: return_type ~loc
722658 | AmbiguousMatch ps ->
723659 Semantic_error. ambiguous_function_promotion loc fname.name None ps
724660 |> error
725661 | SignatureErrors (expected_args , err ) ->
726- Semantic_error. illtyped_variadic_dae loc id.name
662+ Semantic_error. illtyped_variadic loc id.name
727663 (List. map ~f: type_of_expr_typed tes)
728- expected_args err
664+ expected_args required_fn_rt err
729665 |> error )
730- | _ -> fail ()
666+ | _ ->
667+ let expected_args, err =
668+ SignatureMismatch. check_variadic_args ~allow_lpdf: false control_args
669+ required_fn_args required_fn_rt (get_arg_types tes)
670+ |> Result. error |> Option. value_exn in
671+ Semantic_error. illtyped_variadic loc id.name
672+ (List. map ~f: type_of_expr_typed tes)
673+ expected_args required_fn_rt err
674+ |> error
731675
732676and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list ) =
733677 let name_check =
0 commit comments