-
Notifications
You must be signed in to change notification settings - Fork 25
Expand file tree
/
Copy pathreplacement.py
More file actions
936 lines (753 loc) · 35.7 KB
/
Copy pathreplacement.py
File metadata and controls
936 lines (753 loc) · 35.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
"""Java code replacement.
This module provides functionality to replace function implementations
in Java source code while preserving formatting and structure.
Supports optimizations that add:
- New static fields
- New helper methods
- Additional class-level members
"""
from __future__ import annotations
import logging
import re
import textwrap
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)
@dataclass
class ParsedOptimization:
    """Pieces extracted from a generated Java optimization candidate.

    Attributes:
        target_method_source: Source text of the replacement for the target
            method. An empty string signals that the candidate is invalid.
        new_fields: Source text of each field declaration to add.
        helpers_before_target: Helper methods that appear before the target
            method in the optimized code.
        helpers_after_target: Helper methods that appear after the target
            method in the optimized code.
        modified_constructors: Constructor sources that must replace the
            corresponding constructors in the original class.
    """

    target_method_source: str
    new_fields: list[str]
    helpers_before_target: list[str] = field(default_factory=list)
    helpers_after_target: list[str] = field(default_factory=list)
    modified_constructors: list[str] = field(default_factory=list)
def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization:
    """Split generated optimization code into the target method and extra members.

    ``new_source`` is either a bare method definition or a class that holds
    the method together with additional fields, helper methods and
    constructors.

    Args:
        new_source: The optimization source code.
        target_method_name: Name of the method being optimized.
        analyzer: JavaAnalyzer instance.

    Returns:
        ParsedOptimization with the method and any additional members. When
        the generated code defines methods but none is named
        ``target_method_name``, ``target_method_source`` is empty so callers
        can reject the candidate.
    """
    source_lines = new_source.splitlines(keepends=True)

    def member_text(member: Any) -> str:
        """Return the member's full lines, starting at its Javadoc when it has one."""
        first = (member.javadoc_start_line or member.start_line) - 1
        return "".join(source_lines[first : member.end_line])

    extracted_fields: list[str] = []
    before_target: list[str] = []
    after_target: list[str] = []
    constructors: list[str] = []
    # Unless proven otherwise, the whole snippet is the replacement method.
    method_text = new_source

    if not analyzer.find_classes(new_source):
        # Standalone method (or snippet). Reject it when it defines methods but
        # none of them is the target: applying it would corrupt the original.
        declared = analyzer.find_methods(new_source)
        if declared and all(m.name != target_method_name for m in declared):
            logger.warning(
                "Generated standalone method '%s' does not match target method '%s'. "
                "Skipping candidate to avoid corrupting the source.",
                declared[0].name,
                target_method_name,
            )
            method_text = ""
        return ParsedOptimization(
            target_method_source=method_text,
            new_fields=extracted_fields,
            helpers_before_target=before_target,
            helpers_after_target=after_target,
            modified_constructors=constructors,
        )

    # A class was generated: locate the first method carrying the target name.
    methods = analyzer.find_methods(new_source)
    target_index = next((pos for pos, m in enumerate(methods) if m.name == target_method_name), None)
    target = methods[target_index] if target_index is not None else None

    if target:
        method_text = member_text(target)
    else:
        logger.warning("Generated class does not contain target method '%s'. Skipping candidate.", target_method_name)
        method_text = ""

    # Every other method is a helper, bucketed by its position relative to the
    # target. Methods lying entirely inside the target's line range belong to
    # anonymous/inner classes in its body and must not be hoisted out.
    for pos, candidate in enumerate(methods):
        if candidate.name == target_method_name:
            continue
        if target and (candidate.start_line >= target.start_line and candidate.end_line <= target.end_line):
            continue
        bucket = before_target if target_index is None or pos < target_index else after_target
        bucket.append(member_text(candidate))

    # Constructors of the target's class are captured so that newly added
    # fields get initialised when the originals are replaced. Line-sliced text
    # keeps the leading whitespace that later dedenting relies on.
    if target:
        constructors.extend(
            member_text(ctor) for ctor in analyzer.find_constructors(new_source, class_name=target.class_name)
        )

    # Fields are restricted to the target's class so that inner/anonymous
    # class fields are not injected into the outer class.
    owner_class = target.class_name if target else None
    extracted_fields.extend(
        declared_field.source_text
        for declared_field in analyzer.find_fields(new_source, class_name=owner_class)
        if declared_field.source_text
    )

    return ParsedOptimization(
        target_method_source=method_text,
        new_fields=extracted_fields,
        helpers_before_target=before_target,
        helpers_after_target=after_target,
        modified_constructors=constructors,
    )
def _dedent_member(source: str) -> str:
"""Strip the common leading whitespace from a class member source."""
return textwrap.dedent(source).strip()
def _lines_to_insert_byte(source_lines: list[str], end_line_1indexed: int) -> int:
"""Return the byte offset immediately after the given 1-indexed line."""
return sum(len(ln.encode("utf8")) for ln in source_lines[:end_line_1indexed])
def _insert_class_members(
    source: str,
    class_name: str,
    fields: list[str],
    helpers_before_target: list[str],
    helpers_after_target: list[str],
    target_method_name: str | None,
    analyzer: JavaAnalyzer,
) -> str:
    """Insert new class members (fields and helper methods) into a class.

    Fields are inserted after the last existing field declaration (or right
    after the opening brace when the class has no fields yet).

    Helpers that appear *before* the target method in the optimized code are
    inserted immediately before that method in the original source.

    Helpers that appear *after* the target method in the optimized code are
    appended at the end of the class body (before the closing brace).

    Every injected member is dedented and then re-indented one nesting level
    (four spaces) deeper than the class declaration line.

    Args:
        source: The source code.
        class_name: Name of the class to modify.
        fields: Field source texts to insert.
        helpers_before_target: Helper methods that precede the target in the optimised code.
        helpers_after_target: Helper methods that follow the target in the optimised code.
        target_method_name: Name of the method being replaced (used to locate insertion point).
        analyzer: JavaAnalyzer instance.

    Returns:
        Modified source code, or ``source`` unchanged when there is nothing to
        insert or the class cannot be found.
    """
    if not fields and not helpers_before_target and not helpers_after_target:
        return source

    def get_target_class_and_body(src: str) -> tuple[Any, Any]:
        """Return the class named ``class_name`` and its body node, or ``(None, None)``."""
        for cls in analyzer.find_classes(src):
            if cls.name == class_name:
                return cls, cls.node.child_by_field_name("body")
        return None, None

    target_class, body_node = get_target_class_and_body(source)
    if not target_class or not body_node:
        logger.warning("Could not find class %s to insert members", class_name)
        return source

    lines_list = source.splitlines(keepends=True)
    class_line = target_class.start_line - 1
    class_indent = _get_indentation(lines_list[class_line]) if class_line < len(lines_list) else ""
    # Members sit one nesting level below the class declaration.
    member_indent = class_indent + "    "

    def format_member(raw: str) -> str:
        """Dedent then re-indent a class member, guaranteeing a trailing newline."""
        member_lines = _dedent_member(raw).splitlines(keepends=True)
        indented = _apply_indentation(member_lines, member_indent)
        if indented and not indented.endswith("\n"):
            indented += "\n"
        return indented

    def splice(text: str, byte_offset: int, insertion: str) -> str:
        """Insert ``insertion`` at a UTF-8 byte offset (parser offsets are in bytes)."""
        raw = text.encode("utf8")
        return (raw[:byte_offset] + insertion.encode("utf8") + raw[byte_offset:]).decode("utf8")

    result = source

    # 1. Fields go after the last existing field, or right after the opening brace.
    if fields:
        _, body_node = get_target_class_and_body(result)
        if body_node:
            existing_fields = analyzer.find_fields(result, class_name=class_name)
            field_text = "".join(format_member(f) for f in fields)
            if existing_fields:
                last_field = max(existing_fields, key=lambda f: f.end_line)
                insert_byte = _lines_to_insert_byte(result.splitlines(keepends=True), last_field.end_line)
            else:
                insert_byte = body_node.start_byte + 1  # after opening brace
                field_text = "\n" + field_text
            result = splice(result, insert_byte, field_text)

    # 2. Helpers that preceded the target are placed just before the target method.
    if helpers_before_target and target_method_name:
        named = [m for m in analyzer.find_methods(result) if m.name == target_method_name]
        # Prefer the method that lives in the class being modified; a method
        # with the same name in another class must not become the anchor.
        # Fall back to any same-named method when no class-scoped one exists.
        candidates = [m for m in named if m.class_name == class_name] or named
        if candidates:
            target_m = candidates[0]
            insert_line = (target_m.javadoc_start_line or target_m.start_line) - 1  # 0-indexed
            insert_byte = _lines_to_insert_byte(result.splitlines(keepends=True), insert_line)
            # Each helper is followed by a blank line.
            method_text = "".join(format_member(h) + "\n" for h in helpers_before_target)
            result = splice(result, insert_byte, method_text)

    # 3. Helpers that followed the target are appended before the closing brace.
    if helpers_after_target:
        _, body_node = get_target_class_and_body(result)
        if body_node:
            insert_point = body_node.end_byte - 1  # before closing brace
            method_text = "\n" + "".join(format_member(h) + "\n" for h in helpers_after_target)
            result = splice(result, insert_point, method_text)

    return result
def _replace_constructors(
    source: str, class_name: str, new_constructor_sources: list[str], analyzer: JavaAnalyzer
) -> str:
    """Swap constructors of ``class_name`` for their updated versions.

    Constructors are paired by the text of their formal parameter list. A
    matching constructor is replaced in place and keeps the original
    indentation. Updated constructors with no counterpart in the source are
    skipped; inserting them as new members is not handled here.

    Args:
        source: The original source code to modify.
        class_name: Name of the class whose constructors should be replaced.
        new_constructor_sources: Source text of each updated constructor.
        analyzer: JavaAnalyzer instance.

    Returns:
        Modified source code with constructors replaced.
    """
    if not new_constructor_sources:
        return source
    if not analyzer.find_constructors(source, class_name=class_name):
        return source

    def signature_of(ctor: Any) -> str:
        """Return the stripped parameter-list text, treating a missing one as ``()``."""
        return (ctor.formal_parameters_text or "()").strip()

    updated = source
    for replacement_src in new_constructor_sources:
        # A bare constructor cannot be parsed on its own, so wrap it in a dummy class.
        wrapped = f"class __Dummy__ {{\n{replacement_src}\n}}"
        parsed = analyzer.find_constructors(wrapped)
        if not parsed:
            continue
        wanted_params = signature_of(parsed[0])

        # Search the current text: earlier iterations may already have shifted lines.
        existing = next(
            (
                ctor
                for ctor in analyzer.find_constructors(updated, class_name=class_name)
                if signature_of(ctor) == wanted_params
            ),
            None,
        )
        if not existing:
            logger.debug("No matching constructor with params %s found in class %s; skipping.", wanted_params, class_name)
            continue

        # The replaced range starts at the Javadoc when there is one.
        first_line = existing.javadoc_start_line or existing.start_line
        last_line = existing.end_line
        current_lines = updated.splitlines(keepends=True)
        anchor = current_lines[first_line - 1] if first_line <= len(current_lines) else ""
        indent = _get_indentation(anchor)

        # Drop the class-level indentation the snippet arrived with, then
        # apply the indentation of the constructor being replaced.
        replacement = _apply_indentation(_dedent_member(replacement_src).splitlines(keepends=True), indent)
        if replacement and not replacement.endswith("\n"):
            replacement += "\n"

        updated = "".join(current_lines[: first_line - 1]) + replacement + "".join(current_lines[last_line:])
        logger.debug("Replaced constructor %s(%s) in class %s", class_name, wanted_params, class_name)

    return updated
def replace_function(
    source: str, function: FunctionToOptimize, new_source: str, analyzer: JavaAnalyzer | None = None
) -> str:
    """Replace a function in source code with new implementation.

    Supports optimizations that include:
    - Just the method being optimized
    - A class with the method plus additional fields, helper methods and
      updated constructors

    When ``new_source`` contains a full class, fields and helper methods whose
    names do not already exist in ``source`` are added to the original class,
    and constructors with a matching parameter list are replaced.

    The replaced range starts at the method's Javadoc when it has one, and the
    new method is re-indented to match the first line of that range.

    Args:
        source: Original source code.
        function: FunctionToOptimize identifying the function to replace.
        new_source: New function source code (may include class with helpers).
        analyzer: Optional JavaAnalyzer instance.

    Returns:
        Modified source code with function replaced and any new members added.
        ``source`` is returned unchanged when the candidate contains no valid
        replacement or the target method cannot be found.
    """
    analyzer = analyzer or get_java_analyzer()

    func_name = function.function_name
    func_start_line = function.starting_line
    func_end_line = function.ending_line

    # Largest start-line difference still accepted as "the same overload";
    # absorbs small shifts caused by different parsing or file transformations.
    max_line_drift = 5

    # Parse the optimization to extract components.
    parsed = _parse_optimization_source(new_source, func_name, analyzer)
    if not parsed.target_method_source.strip():
        logger.warning("No valid replacement found for method '%s'. Returning original source.", func_name)
        return source

    def find_candidates(text: str) -> tuple[list[Any], list[Any]]:
        """Return (all methods, methods matching the target's name and class) for ``text``."""
        all_methods = analyzer.find_methods(text)
        same_name = [
            m
            for m in all_methods
            if m.name == func_name and (function.class_name is None or m.class_name == function.class_name)
        ]
        return all_methods, same_name

    methods, matching_methods = find_candidates(source)
    target_method = None
    target_overload_index = 0  # Position of the chosen overload among matching_methods

    if len(matching_methods) == 1:
        target_method = matching_methods[0]
    elif len(matching_methods) > 1:
        # Multiple overloads - use line numbers to find the exact one
        logger.debug(
            "Found %d overloads of %s. Function start_line=%s, end_line=%s",
            len(matching_methods),
            func_name,
            func_start_line,
            func_end_line,
        )
        for i, m in enumerate(matching_methods):
            logger.debug(" Overload %d: lines %d-%d", i, m.start_line, m.end_line)

        if func_start_line and func_end_line:
            # Choose the overload whose start line is CLOSEST to the requested
            # one. Taking the first overload inside the tolerance window would
            # pick the wrong method when short overloads sit next to each other.
            closest_index = min(
                range(len(matching_methods)),
                key=lambda idx: abs(matching_methods[idx].start_line - func_start_line),
            )
            closest = matching_methods[closest_index]
            if abs(closest.start_line - func_start_line) <= max_line_drift:
                target_method = closest
                target_overload_index = closest_index
                logger.debug(
                    "Matched overload %d at lines %d-%d (target: %d-%d)",
                    closest_index,
                    closest.start_line,
                    closest.end_line,
                    func_start_line,
                    func_end_line,
                )

        if not target_method:
            # Fallback: use the first match
            logger.warning("Multiple overloads of %s found but no line match, using first match", func_name)
            target_method = matching_methods[0]
            target_overload_index = 0

    if not target_method:
        logger.error("Could not find method %s in source", func_name)
        return source

    # Get the class name for inserting new members
    class_name = target_method.class_name or function.class_name

    # First, add any new fields and helper methods to the class
    if class_name and (parsed.new_fields or parsed.helpers_before_target or parsed.helpers_after_target):
        existing_methods = {m.name for m in methods}
        existing_fields = {f.name for f in analyzer.find_fields(source)}

        def is_new_helper(helper_src: str) -> bool:
            """True when the helper parses to a method whose name is not yet in the source."""
            declared = analyzer.find_methods(helper_src)
            return bool(declared) and declared[0].name not in existing_methods

        new_helpers_before = [h for h in parsed.helpers_before_target if is_new_helper(h)]
        new_helpers_after = [h for h in parsed.helpers_after_target if is_new_helper(h)]

        new_fields_to_add = []
        for field_src in parsed.new_fields:
            # A field declaration only parses inside a class, so wrap it in a dummy one.
            dummy_class = f"class __DummyClass__ {{\n{field_src}\n}}"
            # Keep the declaration (once) if it introduces at least one unknown field name.
            if any(info.name not in existing_fields for info in analyzer.find_fields(dummy_class)):
                new_fields_to_add.append(field_src)

        if new_fields_to_add or new_helpers_before or new_helpers_after:
            logger.debug(
                "Adding %d new fields, %d before-helpers, %d after-helpers to class %s",
                len(new_fields_to_add),
                len(new_helpers_before),
                len(new_helpers_after),
                class_name,
            )
            source = _insert_class_members(
                source, class_name, new_fields_to_add, new_helpers_before, new_helpers_after, func_name, analyzer
            )

            # Line numbers have shifted, but the relative order of overloads is
            # preserved, so re-find the target by its saved overload index.
            methods, matching_methods = find_candidates(source)
            if matching_methods and target_overload_index < len(matching_methods):
                target_method = matching_methods[target_overload_index]
                logger.debug(
                    "Re-found target method at overload index %d (lines %d-%d after shift)",
                    target_overload_index,
                    target_method.start_line,
                    target_method.end_line,
                )
            else:
                logger.error(
                    "Lost target method %s after adding members (had index %d, found %d overloads)",
                    func_name,
                    target_overload_index,
                    len(matching_methods),
                )
                return source

    # Determine replacement range (include Javadoc if present)
    start_line = target_method.javadoc_start_line or target_method.start_line
    end_line = target_method.end_line

    lines = source.splitlines(keepends=True)

    # Re-indent the new method to match the first line of the replaced range.
    original_first_line = lines[start_line - 1] if start_line <= len(lines) else ""
    indent = _get_indentation(original_first_line)
    new_source_lines = parsed.target_method_source.splitlines(keepends=True)
    indented_new_source = _apply_indentation(new_source_lines, indent)

    # A trailing newline keeps the following line from being glued to the method.
    if indented_new_source and not indented_new_source.endswith("\n"):
        indented_new_source += "\n"

    before = lines[: start_line - 1]  # Lines before the method
    after = lines[end_line:]  # Lines after the method
    result = "".join(before) + indented_new_source + "".join(after)

    # Replace constructors the optimization rewrote, so that newly added
    # fields are initialised.
    if class_name and parsed.modified_constructors:
        result = _replace_constructors(result, class_name, parsed.modified_constructors, analyzer)

    return result
def _get_indentation(line: str) -> str:
"""Extract the indentation from a line.
Args:
line: The line to analyze.
Returns:
The indentation string (spaces/tabs).
"""
match = re.match(r"^(\s*)", line)
return match.group(1) if match else ""
def _apply_indentation(lines: list[str], base_indent: str) -> str:
"""Apply indentation to all lines.
Args:
lines: Lines to indent.
base_indent: Base indentation to apply.
Returns:
Indented source code.
"""
if not lines:
return ""
# Detect the existing indentation from the first non-empty line
# This includes Javadoc/comment lines to handle them correctly
existing_indent = ""
for line in lines:
if line.strip(): # First non-empty line
existing_indent = _get_indentation(line)
break
result_lines = []
for line in lines:
if not line.strip():
result_lines.append(line)
else:
# Remove existing indentation and apply new base indentation
stripped_line = line.lstrip()
# Calculate relative indentation
line_indent = _get_indentation(line)
# When existing_indent is empty (first line has no indent), the relative
# indent is the full line indent. Otherwise, calculate the difference.
if line_indent.startswith(existing_indent):
relative_indent = line_indent[len(existing_indent) :]
else:
relative_indent = ""
result_lines.append(base_indent + relative_indent + stripped_line)
return "".join(result_lines)
def replace_method_body(
    source: str, function: FunctionToOptimize, new_body: str, analyzer: JavaAnalyzer | None = None
) -> str:
    """Replace just the body of a method, preserving signature.

    The first method whose name (and class, when ``function.class_name`` is
    set) matches is used. ``new_body`` may be given with or without its
    enclosing braces:

    - Without a leading ``{`` it is treated as the code *between* the braces
      and wrapped in a complete ``{ ... }`` pair, even when it happens to end
      with the ``}`` of a nested block.
    - With a leading ``{`` but no trailing ``}``, a closing brace is appended.

    Args:
        source: Original source code.
        function: FunctionToOptimize identifying the function.
        new_body: New method body (code between braces).
        analyzer: Optional JavaAnalyzer instance.

    Returns:
        Modified source code, or ``source`` unchanged when the method is not
        found or has no body.
    """
    analyzer = analyzer or get_java_analyzer()
    source_bytes = source.encode("utf8")
    func_name = function.function_name

    # Find the method
    target_method = None
    for method in analyzer.find_methods(source):
        if method.name == func_name and (function.class_name is None or method.class_name == function.class_name):
            target_method = method
            break

    if not target_method:
        logger.error("Could not find method %s", func_name)
        return source

    body_node = target_method.node.child_by_field_name("body")
    if not body_node:
        logger.error("Method %s has no body (abstract?)", func_name)
        return source

    # Byte offsets are used because the node positions index the encoded text.
    body_start = body_node.start_byte
    body_end = body_node.end_byte

    # Indentation of the line on which the body's opening brace sits.
    one_level = "    "
    body_start_line = body_node.start_point[0]
    lines = source.splitlines(keepends=True)
    base_indent = _get_indentation(lines[body_start_line]) if body_start_line < len(lines) else one_level

    body_text = new_body.strip()
    if not body_text.startswith("{"):
        # Bare statements: add BOTH braces. Adding the closing brace only when
        # the text does not already end in "}" left bodies such as
        # "if (a) { ... }" without their closing brace.
        body_text = "{\n" + base_indent + one_level + body_text + "\n" + base_indent + "}"
    elif not body_text.endswith("}"):
        body_text = body_text + "\n" + base_indent + "}"

    before = source_bytes[:body_start]
    after = source_bytes[body_end:]
    return (before + body_text.encode("utf8") + after).decode("utf8")
def insert_method(
    source: str,
    class_name: str,
    method_source: str,
    position: str = "end",  # "end" or "start"
    analyzer: JavaAnalyzer | None = None,
) -> str:
    """Insert a new method into a class.

    The method is dedented, re-indented one nesting level (four spaces) deeper
    than the class declaration line, and inserted after a single newline
    separator.

    Args:
        source: The source code.
        class_name: Name of the class to insert into.
        method_source: Source code of the method to insert.
        position: Where to insert: ``"end"`` places the method before the
            class body's closing brace; any other value places it right after
            the opening brace.
        analyzer: Optional JavaAnalyzer instance.

    Returns:
        Source code with method inserted, or ``source`` unchanged when the
        class or its body cannot be found.
    """
    analyzer = analyzer or get_java_analyzer()

    # Find the class
    target_class = next((cls for cls in analyzer.find_classes(source) if cls.name == class_name), None)
    if not target_class:
        logger.error("Could not find class %s", class_name)
        return source

    body_node = target_class.node.child_by_field_name("body")
    if not body_node:
        logger.error("Class %s has no body", class_name)
        return source

    # Offsets are byte positions into the encoded source.
    source_bytes = source.encode("utf8")
    if position == "end":
        insert_point = body_node.end_byte - 1  # before the closing brace
    else:
        insert_point = body_node.start_byte + 1  # after the opening brace

    # Members are indented one level (4 spaces) deeper than the class line.
    lines = source.splitlines(keepends=True)
    class_line = target_class.start_line - 1
    class_indent = _get_indentation(lines[class_line]) if class_line < len(lines) else ""
    method_indent = class_indent + "    "

    # Dedent the whole snippet first. Stripping alone removes only the first
    # line's indentation, which makes every following line look over-indented
    # relative to it when the snippet arrives uniformly indented.
    method_lines = _dedent_member(method_source).splitlines(keepends=True)
    indented_method = _apply_indentation(method_lines, method_indent)
    if indented_method and not indented_method.endswith("\n"):
        indented_method += "\n"

    before = source_bytes[:insert_point]
    after = source_bytes[insert_point:]
    separator = "\n"
    return (before + separator.encode("utf8") + indented_method.encode("utf8") + after).decode("utf8")
def remove_method(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer | None = None) -> str:
    """Delete a method, together with its Javadoc, from the source.

    The first method whose name matches ``function.function_name`` (and whose
    class matches ``function.class_name`` when that is set) is removed.

    Args:
        source: The source code.
        function: FunctionToOptimize identifying the method to remove.
        analyzer: Optional JavaAnalyzer instance.

    Returns:
        Source code with method removed, or ``source`` unchanged when no
        matching method exists.
    """
    analyzer = analyzer or get_java_analyzer()
    wanted_name = function.function_name

    doomed = next(
        (
            candidate
            for candidate in analyzer.find_methods(source)
            if candidate.name == wanted_name
            and (function.class_name is None or candidate.class_name == function.class_name)
        ),
        None,
    )
    if not doomed:
        logger.error("Could not find method %s", wanted_name)
        return source

    # Whole lines are dropped, starting at the Javadoc when the method has one.
    first_line = doomed.javadoc_start_line or doomed.start_line
    all_lines = source.splitlines(keepends=True)
    return "".join(all_lines[: first_line - 1] + all_lines[doomed.end_line :])
def remove_test_functions(
    test_source: str, functions_to_remove: list[str], analyzer: JavaAnalyzer | None = None
) -> str:
    """Strip the named test methods out of a test source file.

    Args:
        test_source: Test source code.
        functions_to_remove: List of function names to remove.
        analyzer: Optional JavaAnalyzer instance.

    Returns:
        Test source code with specified functions removed.
    """
    analyzer = analyzer or get_java_analyzer()

    # Work from the bottom of the file upwards (descending start line).
    doomed = sorted(
        (m for m in analyzer.find_methods(test_source) if m.name in functions_to_remove),
        key=lambda m: m.start_line,
        reverse=True,
    )

    remaining = test_source
    for method in doomed:
        # remove_method identifies its target through a FunctionToOptimize.
        descriptor = FunctionToOptimize(
            function_name=method.name,
            file_path=Path("temp.java"),
            starting_line=method.start_line,
            ending_line=method.end_line,
            parents=[],
            is_method=True,
            language="java",
        )
        remaining = remove_method(remaining, descriptor, analyzer)
    return remaining
def add_runtime_comments(
    test_source: str,
    original_runtimes: dict[str, int],
    optimized_runtimes: dict[str, int],
    analyzer: JavaAnalyzer | None = None,
) -> str:
    """Add inline runtime performance comments next to function calls.

    Runtime keys have the format "ClassName.methodName#L{line}", where the
    line number is 1-indexed into ``test_source``. For each matching line an
    inline ``//`` comment produced by ``format_runtime_comment`` is appended.
    Runtimes of several keys that point at the same line are summed.

    Keys are filtered by class: when ``test_source`` declares at least one
    class, only keys whose class prefix names one of those classes are used.
    Class declarations are recognised with any leading modifiers
    (``public final class Foo``, ``abstract class Foo<T>``, ...), and the
    captured name stops at the end of the identifier.

    Args:
        test_source: Test source code to annotate.
        original_runtimes: Map of invocation IDs to original runtimes (ns).
        optimized_runtimes: Map of invocation IDs to optimized runtimes (ns).
        analyzer: Optional JavaAnalyzer instance (not used by this function).

    Returns:
        Test source code with inline runtime comments added, or
        ``test_source`` unchanged when no key applies.
    """
    if not original_runtimes or not optimized_runtimes:
        return test_source

    # Matches a class declaration at the start of a stripped line. Requiring
    # the line to begin with "public class " or "class " would miss headers
    # with other modifiers, and splitting on whitespace would keep generic
    # parameters in the captured name.
    class_decl = re.compile(
        r"(?:(?:public|protected|private|abstract|final|static|strictfp|sealed|non-sealed)\s+)*"
        r"class\s+((?:[^\W\d]|\$)[\w$]*)"
    )
    source_class_names: set[str] = set()
    for source_line in test_source.splitlines():
        declaration = class_decl.match(source_line.strip())
        if declaration:
            source_class_names.add(declaration.group(1))

    # Map line_number -> (original_ns, optimized_ns), summing repeated lines.
    line_runtimes: dict[int, tuple[int, int]] = {}
    for key, orig_ns in original_runtimes.items():
        prefix, marker, line_part = key.partition("#L")
        if not marker:
            continue
        # prefix is "ClassName.methodName"; everything before the first dot is the class.
        if source_class_names and prefix.split(".")[0] not in source_class_names:
            continue
        try:
            line_num = int(line_part)
        except ValueError:
            continue
        if orig_ns > 0:
            # A key missing from the optimized map counts as "no change".
            opt_ns = optimized_runtimes.get(key, orig_ns)
            prev_orig, prev_opt = line_runtimes.get(line_num, (0, 0))
            line_runtimes[line_num] = (prev_orig + orig_ns, prev_opt + opt_ns)

    if not line_runtimes:
        return test_source

    # Imported only once a comment actually has to be rendered.
    from codeflash.code_utils.time_utils import format_runtime_comment

    lines = test_source.splitlines(keepends=True)
    for line_num, (orig_ns, opt_ns) in line_runtimes.items():
        idx = line_num - 1  # convert to 0-indexed
        if idx < 0 or idx >= len(lines):
            continue
        comment = format_runtime_comment(orig_ns, opt_ns, comment_prefix="//")
        line = lines[idx]
        # Put the comment before the line terminator, which is preserved as-is.
        stripped = line.rstrip("\n\r")
        trailing = line[len(stripped) :]
        lines[idx] = f"{stripped} {comment}{trailing}"

    return "".join(lines)