Skip to content

Commit f061ce1

Browse files
committed
DispatchAny
1 parent a6d64d4 commit f061ce1

6 files changed

Lines changed: 210 additions & 69 deletions

File tree

prebindgen-ext/src/jni/jni_kotlin_ext.rs

Lines changed: 79 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,9 @@ fn render_wrapper_fn(
455455
mode: ParamMode,
456456
}
457457
enum ParamMode {
458-
Borrow, // &T opaque-handle → withPtr
459-
Consume, // T opaque-handle → consume
458+
Borrow, // &T opaque-handle → withPtr
459+
Consume, // T opaque-handle → consume
460+
DispatchAny, // `impl Into<T>` (Kotlin `Any`) → runtime: NativeHandle.withPtr or pass-through
460461
PassThrough,
461462
}
462463

@@ -472,17 +473,9 @@ fn render_wrapper_fn(
472473
let entry = registry.input_entry(arg_ty)?;
473474
let is_opaque = converter_returns_owned_object(&entry.function.sig.output);
474475

475-
let mode = if is_opaque {
476-
if matches!(arg_ty, syn::Type::Reference(_)) {
477-
ParamMode::Borrow
478-
} else {
479-
ParamMode::Consume
480-
}
476+
let (kt_type_raw, optional) = if is_opaque {
477+
("NativeHandle".to_string(), false)
481478
} else {
482-
ParamMode::PassThrough
483-
};
484-
485-
let (kt_type, optional) = if matches!(mode, ParamMode::PassThrough) {
486479
// Look up the Kotlin type via the merged type map; fall
487480
// back to deriving from the wire type when the param
488481
// isn't pre-registered (e.g. `impl Into<T>` shapes wired
@@ -497,11 +490,26 @@ fn render_wrapper_fn(
497490
.or_else(|| kotlin_for_wire(&entry.destination))?;
498491
let opt = is_option_type(arg_ty);
499492
(kt, opt)
493+
};
494+
495+
// Mode: opaque → Borrow/Consume by Rust syntactic shape.
496+
// Non-opaque → `Any` triggers DispatchAny (`impl Into<T>` wire
497+
// = JObject; runtime might be a NativeHandle that we should
498+
// unwrap under withPtr to close the race window). Everything
499+
// else (primitives, callbacks, data classes) passes through.
500+
let mode = if is_opaque {
501+
if matches!(arg_ty, syn::Type::Reference(_)) {
502+
ParamMode::Borrow
503+
} else {
504+
ParamMode::Consume
505+
}
506+
} else if kt_type_raw == "Any" {
507+
ParamMode::DispatchAny
500508
} else {
501-
("NativeHandle".to_string(), false)
509+
ParamMode::PassThrough
502510
};
503511

504-
let short = register_fqn(&kt_type, imports);
512+
let short = register_fqn(&kt_type_raw, imports);
505513
let suffix = if optional { "?" } else { "" };
506514
params.push(Param {
507515
kt_name: name,
@@ -513,63 +521,81 @@ fn render_wrapper_fn(
513521
// Return type: peel ZResult<...>; detect opaque-handle return.
514522
let (kt_return, return_is_opaque) = classify_return(&f.sig.output, registry, kotlin_types, imports)?;
515523

516-
// Body: nest withPtr/consume in declaration order.
517-
let mut call = format!("JNINative.{jni_call}(");
518-
for (i, p) in params.iter().enumerate() {
519-
if i > 0 { call.push_str(", "); }
520-
call.push_str(&p.kt_name);
521-
}
522-
call.push(')');
524+
// Helper: build the JNINative call for a given DispatchAny "unwrap mask".
525+
// mask bit k = 1 means dispatch_indices[k] is unwrapped (use `<name>_ptr`).
526+
let dispatch_indices: Vec<usize> = params
527+
.iter()
528+
.enumerate()
529+
.filter_map(|(i, p)| matches!(p.mode, ParamMode::DispatchAny).then_some(i))
530+
.collect();
523531

524-
// Build innermost expression, wrapping in NativeHandle() if return is opaque.
525-
let mut body_expr = if return_is_opaque {
526-
format!("NativeHandle({call})")
527-
} else {
532+
let build_call = |mask: u32| -> String {
533+
let mut args: Vec<String> = Vec::with_capacity(params.len());
534+
for (i, p) in params.iter().enumerate() {
535+
let arg = match p.mode {
536+
ParamMode::Borrow | ParamMode::Consume => format!("{}_ptr", p.kt_name),
537+
ParamMode::DispatchAny => {
538+
let pos = dispatch_indices.iter().position(|&di| di == i).unwrap();
539+
if (mask >> pos) & 1 == 1 {
540+
format!("{}_ptr", p.kt_name)
541+
} else {
542+
p.kt_name.clone()
543+
}
544+
}
545+
ParamMode::PassThrough => p.kt_name.clone(),
546+
};
547+
args.push(arg);
548+
}
549+
let mut call = format!("JNINative.{jni_call}({})", args.join(", "));
550+
if return_is_opaque {
551+
call = format!("NativeHandle({call})");
552+
}
528553
call
529554
};
530555

556+
// Build the DispatchAny decision tree. At each level we pick one
557+
// dispatch_indices[level]: `if (p is NativeHandle) p.withPtr { ... } else { ... }`.
558+
// The base case (level == n) emits the JNINative call with the
559+
// accumulated unwrap mask. Branches recursively split.
560+
fn build_tree(
561+
level: usize,
562+
mask: u32,
563+
dispatch_indices: &[usize],
564+
params: &[(String, /*placeholder for type*/ ())],
565+
build_call: &dyn Fn(u32) -> String,
566+
) -> String {
567+
if level == dispatch_indices.len() {
568+
return build_call(mask);
569+
}
570+
let name = &params[dispatch_indices[level]].0;
571+
let with_branch = build_tree(level + 1, mask | (1 << level), dispatch_indices, params, build_call);
572+
let else_branch = build_tree(level + 1, mask, dispatch_indices, params, build_call);
573+
format!(
574+
"if ({name} is NativeHandle) {name}.withPtr {{ {name}_ptr ->\n {with_branch}\n}} else {{\n {else_branch}\n}}"
575+
)
576+
}
577+
578+
let param_names_for_tree: Vec<(String, ())> = params.iter().map(|p| (p.kt_name.clone(), ())).collect();
579+
let mut body_expr = build_tree(0, 0, &dispatch_indices, &param_names_for_tree, &build_call);
580+
531581
// Wrap with nested withPtr/consume from innermost to outermost.
532-
let mut indent = String::new();
533582
for p in params.iter().rev() {
534583
match p.mode {
535584
ParamMode::Borrow => {
536585
body_expr = format!(
537-
"{name}.withPtr {{ {name}_ptr ->\n{indent} {expr}\n{indent}}}",
586+
"{name}.withPtr {{ {name}_ptr ->\n {expr}\n}}",
538587
name = p.kt_name,
539-
indent = indent,
540588
expr = body_expr,
541589
);
542590
}
543591
ParamMode::Consume => {
544592
body_expr = format!(
545-
"{name}.consume {{ {name}_ptr ->\n{indent} {expr}\n{indent}}}",
593+
"{name}.consume {{ {name}_ptr ->\n {expr}\n}}",
546594
name = p.kt_name,
547-
indent = indent,
548595
expr = body_expr,
549596
);
550597
}
551-
ParamMode::PassThrough => {}
552-
}
553-
// Indentation isn't strictly needed for correctness; keep flat.
554-
}
555-
556-
// Replace `<name>` with `<name>_ptr` inside the call args for opaque params.
557-
// (We built `call` with `name`, but for opaque params the underlying
558-
// JNI fn takes the raw Long; rewrite arg references accordingly.)
559-
let mut fixed = body_expr;
560-
for p in &params {
561-
if matches!(p.mode, ParamMode::Borrow | ParamMode::Consume) {
562-
// Replace the bare parameter name reference inside the JNINative call
563-
// with `<name>_ptr`. Word-boundary-safe substitution.
564-
let needle = format!(", {}", p.kt_name);
565-
let repl = format!(", {}_ptr", p.kt_name);
566-
fixed = fixed.replace(&needle, &repl);
567-
let head_needle = format!("({}", p.kt_name);
568-
let head_repl = format!("({}_ptr", p.kt_name);
569-
fixed = fixed.replace(&head_needle, &head_repl);
570-
let solo_needle = format!("({})", p.kt_name);
571-
let solo_repl = format!("({}_ptr)", p.kt_name);
572-
fixed = fixed.replace(&solo_needle, &solo_repl);
598+
ParamMode::DispatchAny | ParamMode::PassThrough => {}
573599
}
574600
}
575601

@@ -581,7 +607,7 @@ fn render_wrapper_fn(
581607
let _ = write!(out, ": {kt_return}");
582608
}
583609
let _ = writeln!(out, " =");
584-
let _ = writeln!(out, " {fixed}");
610+
let _ = writeln!(out, " {body_expr}");
585611
Some(out)
586612
}
587613

zenoh-jni-runtime/src/commonMain/kotlin/io/zenoh/jni/JNIKeyExpr.kt

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,13 @@ import io.zenoh.exceptions.ZError
3030

3131
/**
3232
* Pick the declared handle if present, else the raw string. Returns
33-
* a JVM `Object` (boxed `java.lang.Long` or `java.lang.String`) which
34-
* the native dispatching converter resolves at runtime.
35-
*
36-
* Takes a [NativeHandle] (not a raw `Long`) so callers don't see the
37-
* inner pointer. The peek is unlocked — the borrow's Rust side runs
38-
* `Arc::increment_strong_count` and a concurrent close of this
39-
* keyExpr is the same race window the broader JNI layer already
40-
* tolerates for borrow-style calls.
33+
* the `NativeHandle` (subtype of `Any`) when a handle is declared,
34+
* else the `String` (also `Any`). The generator-emitted wrappers in
35+
* [JNIWrappers] runtime-dispatch on `is NativeHandle` and acquire
36+
* the read lock via `withPtr` before crossing the JNI boundary, so
37+
* no `Long` ever escapes outside a held lock.
4138
*/
42-
fun keyExprArg(handle: NativeHandle?, str: String): Any =
43-
handle?.peek()?.takeIf { it != 0L } ?: str
39+
fun keyExprArg(handle: NativeHandle?, str: String): Any = handle ?: str
4440

4541
@Throws(ZError::class)
4642
fun keyExprTryFrom(keyExpr: String): String = JNIWrappers.tryFrom(keyExpr)

zenoh-jni-runtime/src/commonMain/kotlin/io/zenoh/jni/JNIQuerier.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,13 @@ public class JNIQuerier(initialPtr: Long) : NativeHandle(initialPtr) {
3737
payload: ByteArray?,
3838
encoding: JNIEncoding?,
3939
) = withPtr { ptr ->
40-
getViaJNI(ptr, keyExprArg(keyExprHandle, keyExprString), parameters, callback, onClose, attachmentBytes, payload, encoding)
40+
if (keyExprHandle != null) {
41+
keyExprHandle.withPtr { kePtr ->
42+
getViaJNI(ptr, kePtr, parameters, callback, onClose, attachmentBytes, payload, encoding)
43+
}
44+
} else {
45+
getViaJNI(ptr, keyExprString, parameters, callback, onClose, attachmentBytes, payload, encoding)
46+
}
4147
}
4248

4349
@Throws(ZError::class)

zenoh-jni-runtime/src/commonMain/kotlin/io/zenoh/jni/JNIQuery.kt

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ public class JNIQuery(initialPtr: Long) : NativeHandle(initialPtr) {
3131
attachment: ByteArray?,
3232
qosExpress: Boolean,
3333
) = withPtr { ptr ->
34-
replySuccessViaJNI(ptr, keyExprArg(keyExprHandle, keyExprString), payload, encoding, timestampEnabled, timestampNtp64, attachment, qosExpress)
34+
if (keyExprHandle != null) {
35+
keyExprHandle.withPtr { kePtr ->
36+
replySuccessViaJNI(ptr, kePtr, payload, encoding, timestampEnabled, timestampNtp64, attachment, qosExpress)
37+
}
38+
} else {
39+
replySuccessViaJNI(ptr, keyExprString, payload, encoding, timestampEnabled, timestampNtp64, attachment, qosExpress)
40+
}
3541
}
3642

3743
@Throws(ZError::class)
@@ -48,7 +54,13 @@ public class JNIQuery(initialPtr: Long) : NativeHandle(initialPtr) {
4854
attachment: ByteArray?,
4955
qosExpress: Boolean,
5056
) = withPtr { ptr ->
51-
replyDeleteViaJNI(ptr, keyExprArg(keyExprHandle, keyExprString), timestampEnabled, timestampNtp64, attachment, qosExpress)
57+
if (keyExprHandle != null) {
58+
keyExprHandle.withPtr { kePtr ->
59+
replyDeleteViaJNI(ptr, kePtr, timestampEnabled, timestampNtp64, attachment, qosExpress)
60+
}
61+
} else {
62+
replyDeleteViaJNI(ptr, keyExprString, timestampEnabled, timestampNtp64, attachment, qosExpress)
63+
}
5264
}
5365

5466
fun close() = close { freePtrViaJNI(it) }

zenoh-jni-runtime/src/commonMain/kotlin/io/zenoh/jni/JNISession.kt

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,16 @@ public class JNISession(initialPtr: Long) : NativeHandle(initialPtr) {
263263
@Throws(ZError::class)
264264
fun declareLivelinessToken(keyExprHandle: NativeHandle?, keyExprString: String): JNILivelinessToken =
265265
withPtr { ptr ->
266-
JNILivelinessToken(declareLivelinessTokenViaJNI(ptr, keyExprArg(keyExprHandle, keyExprString)))
266+
// Mirror JNIWrappers' impl-Into dispatch: hold the keyExpr's
267+
// read lock for the JNI call when a declared handle is in
268+
// play; otherwise pass the validated string.
269+
if (keyExprHandle != null) {
270+
keyExprHandle.withPtr { kePtr ->
271+
JNILivelinessToken(declareLivelinessTokenViaJNI(ptr, kePtr))
272+
}
273+
} else {
274+
JNILivelinessToken(declareLivelinessTokenViaJNI(ptr, keyExprString))
275+
}
267276
}
268277

269278
@Throws(ZError::class)
@@ -277,7 +286,13 @@ public class JNISession(initialPtr: Long) : NativeHandle(initialPtr) {
277286
history: Boolean,
278287
onClose: JNIOnCloseCallback,
279288
): JNISubscriber = withPtr { ptr ->
280-
JNISubscriber(declareLivelinessSubscriberViaJNI(ptr, keyExprArg(keyExprHandle, keyExprString), callback, history, onClose))
289+
if (keyExprHandle != null) {
290+
keyExprHandle.withPtr { kePtr ->
291+
JNISubscriber(declareLivelinessSubscriberViaJNI(ptr, kePtr, callback, history, onClose))
292+
}
293+
} else {
294+
JNISubscriber(declareLivelinessSubscriberViaJNI(ptr, keyExprString, callback, history, onClose))
295+
}
281296
}
282297

283298
@Throws(ZError::class)
@@ -297,7 +312,13 @@ public class JNISession(initialPtr: Long) : NativeHandle(initialPtr) {
297312
timeoutMs: Long,
298313
onClose: JNIOnCloseCallback,
299314
) = withPtr { ptr ->
300-
livelinessGetViaJNI(ptr, keyExprArg(keyExprHandle, keyExprString), callback, timeoutMs, onClose)
315+
if (keyExprHandle != null) {
316+
keyExprHandle.withPtr { kePtr ->
317+
livelinessGetViaJNI(ptr, kePtr, callback, timeoutMs, onClose)
318+
}
319+
} else {
320+
livelinessGetViaJNI(ptr, keyExprString, callback, timeoutMs, onClose)
321+
}
301322
}
302323

303324
@Throws(ZError::class)

0 commit comments

Comments
 (0)