@@ -781,6 +781,9 @@ def _default_call_value(param: Param) -> str:
781781 if param .aten_type in _NULLOPT_BY_TYPE :
782782 return _NULLOPT_BY_TYPE [param .aten_type ]
783783
784+ if param .default is not None :
785+ return _translate_default (param )
786+
784787 return param .hidden_value ()
785788
786789
@@ -1144,6 +1147,8 @@ def _generate_torch_header(name: str, ops: list[Op]) -> str:
11441147def _generate_torch_method_source (name : str , op : Op ) -> str :
11451148 op_type = _op_cpp_type (name )
11461149 conversion_lines = []
1150+ out_device_index = f"{ op .out_params [0 ].api_name } .device().index()"
1151+ conversion_lines .append (f" const auto device_index = { out_device_index } ;" )
11471152
11481153 def _optional_aten_type (param : Param ) -> str :
11491154 return _NULLOPT_BY_TYPE [param .aten_type ].removesuffix ("{}" )
@@ -1157,7 +1162,7 @@ def _optional_aten_value(schema_param: Param, api_param: Param) -> str:
11571162 return (
11581163 f"ToAtenTensor<kDev>({ data_expr } , { api_name } ->shape(), "
11591164 f"{ api_name } ->strides(), { api_name } ->dtype(), "
1160- f"device_index_ )"
1165+ f"device_index )"
11611166 )
11621167
11631168 if schema_param .aten_type == "Scalar?" :
@@ -1211,7 +1216,7 @@ def _append_optional_conversion(schema_param: Param, api_param: Param) -> None:
12111216 conversion_lines .append (
12121217 f" auto at_{ param .name } = ToAtenTensor<kDev>(\n "
12131218 f" { data_expr } , { api_name } _shape_, { api_name } _strides_,\n "
1214- f" { api_name } _type_, device_index_ );"
1219+ f" { api_name } _type_, device_index );"
12151220 )
12161221
12171222 for schema_index , param in enumerate (op .params ):
0 commit comments