3333import org .openrewrite .java .tree .TypeUtils ;
3434
3535import java .util .ArrayList ;
36+ import java .util .Collections ;
3637import java .util .HashMap ;
3738import java .util .HashSet ;
3839import java .util .List ;
@@ -73,17 +74,17 @@ public J visitNewClass(J.NewClass newClass, ExecutionContext ctx) {
7374
7475 // Prose-pattern: see if visitBlock decided this initializer should be wrapped
7576 // with `new HashMap<>(Map.of(..))` or `new HashMap<>(Map.ofEntries(..))`.
76- Map <UUID , List <Expression >> rewrites = getCursor ().getNearestMessage (PROSE_REWRITES_KEY );
77+ Map <UUID , List <J . MethodInvocation >> rewrites = getCursor ().getNearestMessage (PROSE_REWRITES_KEY );
7778 if (rewrites != null ) {
78- List <Expression > proseArgs = rewrites .get (n .getId ());
79- if (proseArgs != null ) {
80- // proseArgs is [k1, v1, k2, v2, ...]
81- int pairCount = proseArgs .size () / 2 ;
82- boolean useEntries = pairCount > 10 ;
79+ List <J .MethodInvocation > puts = rewrites .get (n .getId ());
80+ if (puts != null ) {
81+ boolean useEntries = puts .size () > 10 ;
82+ List <Expression > args = new ArrayList <>();
8383 StringJoiner inner = useEntries ?
8484 new StringJoiner (", " , "Map.ofEntries(" , ")" ) :
8585 new StringJoiner (", " , "Map.of(" , ")" );
86- for (int p = 0 ; p < pairCount ; p ++) {
86+ for (J .MethodInvocation put : puts ) {
87+ args .addAll (put .getArguments ());
8788 if (useEntries ) {
8889 inner .add ("Map.entry(#{any()}, #{any()})" );
8990 } else {
@@ -93,11 +94,14 @@ public J visitNewClass(J.NewClass newClass, ExecutionContext ctx) {
9394 }
9495 String src = "new HashMap<>(" + inner + ")" ;
9596 maybeAddImport ("java.util.Map" );
96- return JavaTemplate .builder (src )
97+ J applied = JavaTemplate .builder (src )
9798 .contextSensitive ()
9899 .imports ("java.util.HashMap" , "java.util.Map" )
99100 .build ()
100- .apply (updateCursor (n ), n .getCoordinates ().replace (), proseArgs .toArray ());
101+ .apply (updateCursor (n ), n .getCoordinates ().replace (), args .toArray ());
102+ // Reattach each put's prefix so the entries land one-per-line and any
103+ // leading comments survive, then autoformat to nest the indentation.
104+ return autoFormat (reattachPairPrefixes (applied , puts , useEntries ), ctx );
101105 }
102106 }
103107
@@ -160,9 +164,39 @@ public J visitNewClass(J.NewClass newClass, ExecutionContext ctx) {
160164 return n ;
161165 }
162166
167+ /**
168+ * Re-applies the absorbed put statements' prefixes to the generated
169+ * {@code new HashMap<>(Map.of(..))} / {@code new HashMap<>(Map.ofEntries(..))} so each
170+ * pair keeps its own line and any leading comments. {@code puts} holds one invocation
171+ * per pair, in order; for {@code Map.of} every other argument (the keys, at even
172+ * indices) takes a prefix, while for {@code Map.ofEntries} every {@code Map.entry(..)}
173+ * argument does.
174+ */
175+ private J reattachPairPrefixes (J applied , List <J .MethodInvocation > puts , boolean useEntries ) {
176+ if (!(applied instanceof J .NewClass )) {
177+ return applied ;
178+ }
179+ J .NewClass nc = (J .NewClass ) applied ;
180+ if (nc .getArguments ().size () != 1 || !(nc .getArguments ().get (0 ) instanceof J .MethodInvocation )) {
181+ return applied ;
182+ }
183+ J .MethodInvocation mapCall = (J .MethodInvocation ) nc .getArguments ().get (0 );
184+ List <Expression > mapArgs = mapCall .getArguments ();
185+ int step = useEntries ? 1 : 2 ;
186+ List <Expression > withPrefixes = new ArrayList <>(mapArgs .size ());
187+ for (int i = 0 ; i < mapArgs .size (); i ++) {
188+ Expression arg = mapArgs .get (i );
189+ if (i % step == 0 ) {
190+ arg = arg .withPrefix (puts .get (i / step ).getPrefix ());
191+ }
192+ withPrefixes .add (arg );
193+ }
194+ return nc .withArguments (Collections .singletonList (mapCall .withArguments (withPrefixes )));
195+ }
196+
163197 @ Override
164198 public J visitBlock (J .Block block , ExecutionContext ctx ) {
165- Map <UUID , List <Expression >> rewrites = new HashMap <>();
199+ Map <UUID , List <J . MethodInvocation >> rewrites = new HashMap <>();
166200 Set <UUID > absorbedPutIds = new HashSet <>();
167201 identifyProseRewrites (block , rewrites , absorbedPutIds );
168202
@@ -185,12 +219,12 @@ public J visitBlock(J.Block block, ExecutionContext ctx) {
185219 * ...
186220 * </pre>
187221 * For each such sequence with at least two puts, record
188- * (initializer UUID, [k1, v1, k2, v2, ...]) in {@code rewrites} and the
189- * absorbed put statement UUIDs in {@code absorbedPutIds}.
222+ * (initializer UUID, the absorbed {@code put(..)} invocations in order) in
223+ * {@code rewrites} and the absorbed put statement UUIDs in {@code absorbedPutIds}.
190224 */
191225 private void identifyProseRewrites (
192226 J .Block block ,
193- Map <UUID , List <Expression >> rewrites ,
227+ Map <UUID , List <J . MethodInvocation >> rewrites ,
194228 Set <UUID > absorbedPutIds ) {
195229 List <Statement > stmts = block .getStatements ();
196230 int i = 0 ;
@@ -208,9 +242,8 @@ private void identifyProseRewrites(
208242 }
209243 J .NewClass initializer = (J .NewClass ) decl .getVariables ().get (0 ).getInitializer ();
210244
211- List <Expression > args = new ArrayList <>();
245+ List <J . MethodInvocation > puts = new ArrayList <>();
212246 List <UUID > absorbedHere = new ArrayList <>();
213- int pairs = 0 ;
214247 int j = i + 1 ;
215248 while (j < stmts .size ()) {
216249 Statement next = stmts .get (j );
@@ -222,13 +255,12 @@ private void identifyProseRewrites(
222255 expressionReferences (kv .get (1 ), targetName )) {
223256 break ;
224257 }
225- args . addAll ( kv );
258+ puts . add (( J . MethodInvocation ) next );
226259 absorbedHere .add (next .getId ());
227- pairs ++;
228260 j ++;
229261 }
230- if (pairs >= 2 && initializer != null ) {
231- rewrites .put (initializer .getId (), args );
262+ if (puts . size () >= 2 && initializer != null ) {
263+ rewrites .put (initializer .getId (), puts );
232264 absorbedPutIds .addAll (absorbedHere );
233265 i = j ;
234266 } else {
0 commit comments