|
57 | 57 | BasicBlock, |
58 | 58 | Branch, |
59 | 59 | Call, |
| 60 | + ComparisonOp, |
60 | 61 | InitStatic, |
61 | 62 | Integer, |
62 | 63 | LoadAddress, |
|
82 | 83 | none_rprimitive, |
83 | 84 | object_pointer_rprimitive, |
84 | 85 | object_rprimitive, |
| 86 | + pointer_rprimitive, |
85 | 87 | ) |
86 | 88 | from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional |
87 | 89 | from mypyc.irbuild.builder import IRBuilder, create_type_params, int_borrow_friendly_op |
|
103 | 105 | ) |
104 | 106 | from mypyc.primitives.exc_ops import ( |
105 | 107 | error_catch_op, |
| 108 | + error_clear_op, |
| 109 | + err_occurred_op, |
106 | 110 | exc_matches_op, |
107 | 111 | get_exc_info_op, |
108 | 112 | get_exc_value_op, |
|
113 | 117 | reraise_exception_op, |
114 | 118 | restore_exc_info_op, |
115 | 119 | ) |
116 | | -from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op |
| 120 | +from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_delattr_op, py_setattr_op |
117 | 121 | from mypyc.primitives.misc_ops import ( |
118 | 122 | check_stop_op, |
119 | 123 | coro_op, |
@@ -940,6 +944,19 @@ def transform_with( |
940 | 944 | is_async: bool, |
941 | 945 | line: int, |
942 | 946 | ) -> None: |
| 947 | + |
| 948 | + if ( |
| 949 | + not is_async |
| 950 | + and isinstance(expr, mypy.nodes.CallExpr) |
| 951 | + and isinstance(expr.callee, mypy.nodes.RefExpr) |
| 952 | + and isinstance(dec := expr.callee.node, mypy.nodes.Decorator) |
| 953 | + and len(dec.decorators) == 1 |
| 954 | + and isinstance(dec1 := dec.decorators[0], mypy.nodes.RefExpr) |
| 955 | + and dec1.node |
| 956 | + and dec1.node.fullname == "contextlib.contextmanager" |
| 957 | + ): |
| 958 | + return _transform_with_contextmanager(builder, expr, target, body, line) |
| 959 | + |
943 | 960 | # This is basically a straight transcription of the Python code in PEP 343. |
944 | 961 | # I don't actually understand why a bunch of it is the way it is. |
945 | 962 | # We could probably optimize the case where the manager is compiled by us, |
@@ -1017,6 +1034,204 @@ def finally_body() -> None: |
1017 | 1034 | ) |
1018 | 1035 |
|
1019 | 1036 |
|
| 1037 | +def _transform_with_contextmanager( |
| 1038 | + builder: IRBuilder, |
| 1039 | + expr: mypy.nodes.CallExpr, |
| 1040 | + target: Lvalue | None, |
| 1041 | + with_body: GenFunc, |
| 1042 | + line: int, |
| 1043 | +) -> None: |
| 1044 | + assert isinstance(expr.callee, mypy.nodes.RefExpr) |
| 1045 | + dec = expr.callee.node |
| 1046 | + assert isinstance(dec, mypy.nodes.Decorator) |
| 1047 | + |
| 1048 | + # mgrv = ctx.__wrapped__(*args, **kwargs) |
| 1049 | + wrapped_call = mypy.nodes.CallExpr( |
| 1050 | + mypy.nodes.MemberExpr(expr.callee, "__wrapped__"), |
| 1051 | + expr.args, |
| 1052 | + expr.arg_kinds, |
| 1053 | + expr.arg_names, |
| 1054 | + ) |
| 1055 | + wrapped_call.line = line |
| 1056 | + gen = builder.accept(wrapped_call) |
| 1057 | + |
| 1058 | + def raise_runtime_error_from_none(msg: str) -> None: |
| 1059 | + runtime_error = builder.load_module_attr_by_fullname("builtins.RuntimeError", line) |
| 1060 | + exc = builder.py_call(runtime_error, [builder.load_str(msg)], line) |
| 1061 | + builder.primitive_op( |
| 1062 | + py_setattr_op, [exc, builder.load_str("__cause__"), builder.none_object()], line |
| 1063 | + ) |
| 1064 | + builder.primitive_op( |
| 1065 | + py_setattr_op, |
| 1066 | + [ |
| 1067 | + exc, |
| 1068 | + builder.load_str("__suppress_context__"), |
| 1069 | + builder.coerce(builder.true(), object_rprimitive, line), |
| 1070 | + ], |
| 1071 | + line, |
| 1072 | + ) |
| 1073 | + builder.call_c(raise_exception_op, [exc], line) |
| 1074 | + builder.add(Unreachable()) |
| 1075 | + |
| 1076 | + # try: |
| 1077 | + # target = next(gen) |
| 1078 | + # except StopIteration: |
| 1079 | + # raise RuntimeError("generator didn't yield") from None |
| 1080 | + mgr_target = builder.call_c(next_raw_op, [gen], line) |
| 1081 | + |
| 1082 | + runtime_block, main_block = BasicBlock(), BasicBlock() |
| 1083 | + builder.add(Branch(mgr_target, runtime_block, main_block, Branch.IS_ERROR)) |
| 1084 | + |
| 1085 | + builder.activate_block(runtime_block) |
| 1086 | + err_occurred = builder.call_c(err_occurred_op, [], line) |
| 1087 | + null = Integer(0, pointer_rprimitive, line) |
| 1088 | + has_error = builder.add(ComparisonOp(err_occurred, null, ComparisonOp.NEQ, line)) |
| 1089 | + implicit_stop_block, error_exc_block = BasicBlock(), BasicBlock() |
| 1090 | + builder.add(Branch(has_error, error_exc_block, implicit_stop_block, Branch.BOOL)) |
| 1091 | + |
| 1092 | + builder.activate_block(error_exc_block) |
| 1093 | + old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line)) |
| 1094 | + stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line) |
| 1095 | + is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line) |
| 1096 | + stop_block, propagate_block = BasicBlock(), BasicBlock() |
| 1097 | + builder.add(Branch(is_stop_iteration, stop_block, propagate_block, Branch.BOOL)) |
| 1098 | + |
| 1099 | + builder.activate_block(propagate_block) |
| 1100 | + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) |
| 1101 | + builder.add(Unreachable()) |
| 1102 | + |
| 1103 | + builder.activate_block(stop_block) |
| 1104 | + builder.call_c(restore_exc_info_op, [builder.read(old_exc)], line) |
| 1105 | + raise_runtime_error_from_none("generator didn't yield") |
| 1106 | + |
| 1107 | + builder.activate_block(implicit_stop_block) |
| 1108 | + raise_runtime_error_from_none("generator didn't yield") |
| 1109 | + |
| 1110 | + builder.activate_block(main_block) |
| 1111 | + |
| 1112 | + # try: |
| 1113 | + # {body} |
| 1114 | + |
| 1115 | + def try_body() -> None: |
| 1116 | + if target: |
| 1117 | + builder.assign(builder.get_assignment_target(target), mgr_target, line) |
| 1118 | + with_body() |
| 1119 | + |
| 1120 | + # except Exception as e: |
| 1121 | + # try: |
| 1122 | + # gen.throw(e) |
| 1123 | + # except StopIteration as e2: |
| 1124 | + # if e2 is not e: |
| 1125 | + # raise |
| 1126 | + # return |
| 1127 | + # except RuntimeError: |
| 1128 | + # # TODO: check the traceback munging |
| 1129 | + # raise |
| 1130 | + # except BaseException: |
| 1131 | + # # approximately |
| 1132 | + # raise |
| 1133 | + |
| 1134 | + def except_body() -> None: |
| 1135 | + exc_original = builder.call_c(get_exc_value_op, [], line) |
| 1136 | + |
| 1137 | + error_block, no_error_block = BasicBlock(), BasicBlock() |
| 1138 | + |
| 1139 | + builder.builder.push_error_handler(error_block) |
| 1140 | + builder.goto_and_activate(BasicBlock()) |
| 1141 | + builder.py_call(builder.py_get_attr(gen, "throw", line), [exc_original], line) |
| 1142 | + builder.goto(no_error_block) |
| 1143 | + builder.builder.pop_error_handler() |
| 1144 | + |
| 1145 | + builder.activate_block(no_error_block) |
| 1146 | + builder.py_call(builder.py_get_attr(gen, "close", line), [], line) |
| 1147 | + builder.add( |
| 1148 | + RaiseStandardError( |
| 1149 | + RaiseStandardError.RUNTIME_ERROR, "generator didn't stop after throw()", line |
| 1150 | + ) |
| 1151 | + ) |
| 1152 | + builder.add(Unreachable()) |
| 1153 | + |
| 1154 | + builder.activate_block(error_block) |
| 1155 | + throw_old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line)) |
| 1156 | + stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line) |
| 1157 | + is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line) |
| 1158 | + stop_block, propagate_block = BasicBlock(), BasicBlock() |
| 1159 | + builder.add(Branch(is_stop_iteration, stop_block, propagate_block, Branch.BOOL)) |
| 1160 | + |
| 1161 | + builder.activate_block(propagate_block) |
| 1162 | + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) |
| 1163 | + builder.add(Unreachable()) |
| 1164 | + |
| 1165 | + builder.activate_block(stop_block) |
| 1166 | + stop_exc = builder.call_c(get_exc_value_op, [], line) |
| 1167 | + is_same_exc = builder.binary_op(stop_exc, exc_original, "is", line) |
| 1168 | + |
| 1169 | + suppress_block, reraise_block = BasicBlock(), BasicBlock() |
| 1170 | + builder.add(Branch(is_same_exc, reraise_block, suppress_block, Branch.BOOL)) |
| 1171 | + |
| 1172 | + builder.activate_block(reraise_block) |
| 1173 | + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) |
| 1174 | + builder.add(Unreachable()) |
| 1175 | + |
| 1176 | + builder.activate_block(suppress_block) |
| 1177 | + builder.call_c(restore_exc_info_op, [builder.read(throw_old_exc)], line) |
| 1178 | + builder.call_c(error_clear_op, [], -1) |
| 1179 | + |
| 1180 | + # TODO: actually do the exceptions |
| 1181 | + handlers = [(None, None, except_body)] |
| 1182 | + |
| 1183 | + # else: |
| 1184 | + # try: |
| 1185 | + # next(gen) |
| 1186 | + # except StopIteration: |
| 1187 | + # pass |
| 1188 | + # else: |
| 1189 | + # try: |
| 1190 | + # raise RuntimeError("generator didn't stop") |
| 1191 | + # finally: |
| 1192 | + # gen.close() |
| 1193 | + |
| 1194 | + def else_body() -> None: |
| 1195 | + value = builder.call_c(next_raw_op, [builder.read(gen)], line) |
| 1196 | + stop_block, close_block = BasicBlock(), BasicBlock() |
| 1197 | + builder.add(Branch(value, stop_block, close_block, Branch.IS_ERROR)) |
| 1198 | + |
| 1199 | + builder.activate_block(close_block) |
| 1200 | + # TODO: this isn't exactly the right order |
| 1201 | + builder.py_call(builder.py_get_attr(gen, "close", line), [], line) |
| 1202 | + builder.add( |
| 1203 | + RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, "generator didn't stop", line) |
| 1204 | + ) |
| 1205 | + builder.add(Unreachable()) |
| 1206 | + |
| 1207 | + builder.activate_block(stop_block) |
| 1208 | + err_occurred = builder.call_c(err_occurred_op, [], line) |
| 1209 | + null = Integer(0, pointer_rprimitive, line) |
| 1210 | + has_error = builder.add(ComparisonOp(err_occurred, null, ComparisonOp.NEQ, line)) |
| 1211 | + implicit_stop_block, error_exc_block = BasicBlock(), BasicBlock() |
| 1212 | + builder.add(Branch(has_error, error_exc_block, implicit_stop_block, Branch.BOOL)) |
| 1213 | + |
| 1214 | + builder.activate_block(error_exc_block) |
| 1215 | + old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line)) |
| 1216 | + stop_iteration = builder.load_module_attr_by_fullname("builtins.StopIteration", line) |
| 1217 | + is_stop_iteration = builder.call_c(exc_matches_op, [stop_iteration], line) |
| 1218 | + explicit_stop_block, propagate_block = BasicBlock(), BasicBlock() |
| 1219 | + builder.add(Branch(is_stop_iteration, explicit_stop_block, propagate_block, Branch.BOOL)) |
| 1220 | + |
| 1221 | + builder.activate_block(propagate_block) |
| 1222 | + builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) |
| 1223 | + builder.add(Unreachable()) |
| 1224 | + |
| 1225 | + builder.activate_block(explicit_stop_block) |
| 1226 | + builder.call_c(restore_exc_info_op, [builder.read(old_exc)], line) |
| 1227 | + builder.goto(implicit_stop_block) |
| 1228 | + |
| 1229 | + builder.activate_block(implicit_stop_block) |
| 1230 | + builder.call_c(error_clear_op, [], -1) |
| 1231 | + |
| 1232 | + transform_try_except(builder, try_body, handlers, else_body, line) |
| 1233 | + |
| 1234 | + |
1020 | 1235 | def transform_with_stmt(builder: IRBuilder, o: WithStmt) -> None: |
1021 | 1236 | # Generate separate logic for each expr in it, left to right |
1022 | 1237 | def generate(i: int) -> None: |
|
0 commit comments