|
4 | 4 | from typing import Dict, Tuple, List, Callable |
5 | 5 | import pathlib |
6 | 6 | import re |
| 7 | +import os |
7 | 8 |
|
8 | 9 | BREAKLINE_TOKEN = "<n>" |
9 | 10 |
|
| 11 | +BREAKLINE_PER_FILE_TYPE = { |
| 12 | + ".md": "\n", |
| 13 | + ".py": r"\n" |
| 14 | +} |
| 15 | + |
10 | 16 | # --------------------------------------------------------------------------- # |
11 | 17 | # User-facing API |
12 | 18 | # --------------------------------------------------------------------------- # |
13 | 19 | def text_to_patch(text: str, orig: Dict[str, str]) -> Tuple[Patch, int]: |
14 | 20 | """High-level function to parse patch text against original file content.""" |
15 | 21 | lines = text.splitlines() |
| 22 | + for i, line in enumerate(lines): |
| 23 | + if line.startswith(("@", "***")): |
| 24 | + continue |
| 25 | + |
| 26 | + elif line.startswith("---") or not line.startswith(("+", "-", " ")): |
| 27 | + lines[i] = f" {line}" |
| 28 | + |
| 29 | + # print(f"\n\n{lines[-2:]=}") |
| 30 | + |
16 | 31 | if not lines or not Parser._norm(lines[0]).startswith("*** Begin Patch"): |
17 | 32 | raise DiffError("Invalid patch text - must start with '*** Begin Patch'.") |
18 | 33 | if not Parser._norm(lines[-1]) == "*** End Patch": |
@@ -146,8 +161,9 @@ def open_file(path: str) -> str: |
146 | 161 | def write_file(path: str, content: str) -> None: |
147 | 162 | target = pathlib.Path(path) |
148 | 163 | target.parent.mkdir(parents=True, exist_ok=True) |
| 164 | + _, ext = os.path.splitext(target) |
149 | 165 | with target.open("wt", encoding="utf-8", newline="\n") as fh: |
150 | | - fh.write(content.replace(BREAKLINE_TOKEN, "\\n")) |
| 166 | + fh.write(content.replace(BREAKLINE_TOKEN, BREAKLINE_PER_FILE_TYPE.get(ext, r"\n"))) |
151 | 167 |
|
152 | 168 |
|
153 | 169 | def remove_file(path: str) -> None: |
|
0 commit comments