Skip to content

Commit 4e38d0e

Browse files
committed
x
1 parent 7f564a4 commit 4e38d0e

1 file changed

Lines changed: 140 additions & 62 deletions

File tree

rust/src/string.rs

Lines changed: 140 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,139 @@ unsafe fn write_replaced_bytes(
295295
}
296296
}
297297

298+
enum ReplaceBytesFlow {
299+
Return(bool),
300+
Proceed {
301+
current_len: c_uint,
302+
hay_len: usize,
303+
needle_len: usize,
304+
replace_len: usize,
305+
},
306+
}
307+
308+
struct ReplaceBytesScratch {
309+
ptr: *mut c_char,
310+
written: usize,
311+
new_len: c_uint,
312+
}
313+
314+
#[inline]
315+
fn validate_replace_bytes(
316+
buffer: *mut c_char,
317+
current_len: c_uint,
318+
find: *const c_char,
319+
find_len: c_uint,
320+
replace: *const c_char,
321+
replace_len: c_uint,
322+
) -> ReplaceBytesFlow {
323+
if current_len == 0 || find_len == 0 {
324+
return ReplaceBytesFlow::Return(true);
325+
}
326+
if buffer.is_null() || find.is_null() {
327+
return ReplaceBytesFlow::Return(false);
328+
}
329+
if replace.is_null() && replace_len != 0 {
330+
return ReplaceBytesFlow::Return(false);
331+
}
332+
333+
let hay_len = current_len as usize;
334+
let needle_len = find_len as usize;
335+
if needle_len > hay_len {
336+
return ReplaceBytesFlow::Return(true);
337+
}
338+
339+
ReplaceBytesFlow::Proceed {
340+
current_len,
341+
hay_len,
342+
needle_len,
343+
replace_len: replace_len as usize,
344+
}
345+
}
346+
347+
#[inline]
348+
fn prepare_replace_scratch(
349+
buffer: *const c_char,
350+
find: *const c_char,
351+
replace: *const c_char,
352+
flow: &ReplaceBytesFlow,
353+
) -> Result<Option<ReplaceBytesScratch>, ()> {
354+
let ReplaceBytesFlow::Proceed {
355+
current_len,
356+
hay_len,
357+
needle_len,
358+
replace_len,
359+
} = flow
360+
else {
361+
return Err(());
362+
};
363+
364+
// SAFETY: `flow` is validated by `validate_replace_bytes`, and pointers/lengths here are coherent.
365+
unsafe {
366+
let match_count = count_non_overlapping_matches(buffer, *hay_len, find, *needle_len);
367+
if match_count == 0 {
368+
return Ok(None);
369+
}
370+
371+
let Some(new_len) = compute_replaced_len(
372+
*current_len,
373+
*needle_len as c_uint,
374+
*replace_len as c_uint,
375+
match_count,
376+
) else {
377+
return Err(());
378+
};
379+
380+
let scratch_size = (new_len as usize).saturating_add(1);
381+
let scratch_ptr = malloc(scratch_size) as *mut c_char;
382+
if scratch_ptr.is_null() {
383+
return Err(());
384+
}
385+
386+
let written = write_replaced_bytes(
387+
buffer,
388+
*hay_len,
389+
find,
390+
*needle_len,
391+
replace,
392+
*replace_len,
393+
scratch_ptr,
394+
scratch_size,
395+
);
396+
397+
Ok(Some(ReplaceBytesScratch {
398+
ptr: scratch_ptr,
399+
written,
400+
new_len,
401+
}))
402+
}
403+
}
404+
405+
#[inline]
406+
fn commit_replace_scratch(
407+
buffer: &mut *mut c_char,
408+
capacity: &mut c_uint,
409+
len: &mut c_uint,
410+
scratch: ReplaceBytesScratch,
411+
) -> bool {
412+
if !ensure_capacity(buffer, capacity, len, scratch.new_len) {
413+
// SAFETY: pointer came from malloc in `prepare_replace_scratch`.
414+
unsafe { free(scratch.ptr.cast()) }
415+
return false;
416+
}
417+
418+
// SAFETY: destination has enough capacity; source points to scratch allocation.
419+
unsafe {
420+
ptr::copy_nonoverlapping(
421+
scratch.ptr.cast::<u8>(),
422+
(*buffer).cast::<u8>(),
423+
scratch.written + 1,
424+
);
425+
free(scratch.ptr.cast());
426+
}
427+
*len = scratch.new_len;
428+
true
429+
}
430+
298431
#[inline]
299432
fn ensure_capacity(
300433
buffer: &mut *mut c_char,
@@ -1133,71 +1266,16 @@ pub extern "C" fn arduino_string_replace_bytes(
11331266
return false;
11341267
};
11351268

1136-
let current_len = *len;
1137-
if current_len == 0 || find_len == 0 {
1138-
return true;
1139-
}
1140-
if (*buffer).is_null() || find.is_null() {
1141-
return false;
1142-
}
1143-
if replace.is_null() && replace_len != 0 {
1144-
return false;
1269+
let flow = validate_replace_bytes(*buffer, *len, find, find_len, replace, replace_len);
1270+
if let ReplaceBytesFlow::Return(ok) = &flow {
1271+
return *ok;
11451272
}
11461273

1147-
let hay_len = current_len as usize;
1148-
let needle_len = find_len as usize;
1149-
if needle_len > hay_len {
1150-
return true;
1274+
match prepare_replace_scratch(*buffer, find, replace, &flow) {
1275+
Ok(None) => true,
1276+
Ok(Some(scratch)) => commit_replace_scratch(buffer, capacity, len, scratch),
1277+
Err(()) => false,
11511278
}
1152-
1153-
// SAFETY: inputs were validated above.
1154-
let match_count = unsafe { count_non_overlapping_matches(*buffer, hay_len, find, needle_len) };
1155-
if match_count == 0 {
1156-
return true;
1157-
}
1158-
1159-
let Some(new_len) = compute_replaced_len(current_len, find_len, replace_len, match_count)
1160-
else {
1161-
return false;
1162-
};
1163-
let tmp_size = (new_len as usize).saturating_add(1);
1164-
// SAFETY: malloc returns either null or writable memory.
1165-
let tmp = unsafe { malloc(tmp_size) as *mut c_char };
1166-
if tmp.is_null() {
1167-
return false;
1168-
}
1169-
1170-
// SAFETY: inputs and destination capacity are validated above.
1171-
let written = unsafe {
1172-
write_replaced_bytes(
1173-
*buffer,
1174-
hay_len,
1175-
find,
1176-
needle_len,
1177-
replace,
1178-
replace_len as usize,
1179-
tmp,
1180-
tmp_size,
1181-
)
1182-
};
1183-
1184-
let ok = if !ensure_capacity(buffer, capacity, len, new_len) {
1185-
false
1186-
} else {
1187-
// SAFETY: destination has at least written+1 bytes.
1188-
unsafe {
1189-
ptr::copy_nonoverlapping(tmp.cast::<u8>(), (*buffer).cast::<u8>(), written + 1);
1190-
}
1191-
*len = new_len;
1192-
true
1193-
};
1194-
1195-
// SAFETY: tmp was allocated above with malloc.
1196-
unsafe {
1197-
free(tmp.cast());
1198-
}
1199-
1200-
ok
12011279
}
12021280

12031281
#[unsafe(no_mangle)]

0 commit comments

Comments
 (0)