diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.gitignore b/.gitignore index b89558b..22c127e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk *~ + +# flake.lock +.direnv \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 3b1e074..0c35f09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/openrr/rrt" # Note: num-traits is public dependency. [dependencies] +derive_more = { version = "0.99.17", default-features = false, features = ["display", "error"] } kdtree = "0.7" num-traits = "0.2" rand = "0.8" diff --git a/examples/collision_avoid.rs b/examples/collision_avoid.rs index 23e1762..92d74da 100644 --- a/examples/collision_avoid.rs +++ b/examples/collision_avoid.rs @@ -97,7 +97,7 @@ fn main() { let mut index = 0; while window.render() { if index == path.len() { - path = rrt::dual_rrt_connect( + path = rrt::rrt::dual_rrt_connect( &start, &goal, |x: &[f64]| p.is_feasible(x), @@ -106,7 +106,7 @@ fn main() { 1000, ) .unwrap(); - rrt::smooth_path(&mut path, |x: &[f64]| p.is_feasible(x), 0.05, 100); + rrt::rrt::smooth_path(&mut path, |x: &[f64]| p.is_feasible(x), 0.05, 100); index = 0; } let point = &path[index % path.len()]; diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..ccf1ac7 --- /dev/null +++ b/flake.lock @@ -0,0 +1,85 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1714253743, + "narHash": "sha256-mdTQw2XlariysyScCv2tTE45QSU9v/ezLcHJ22f0Nxc=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "58a1abdbae3217ca6b702f03d3b35125d88a2994", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" + } + }, + "rust-overlay": { + "inputs": { + "flake-utils": [ + "flake-utils" + ], + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1714443211, + "narHash": "sha256-lKTA3XqRo4aVgkyTSCtpcALpGXdmkilHTtN00eRg0QU=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "ce35c36f58f82cee6ec959e0d44c587d64281b6f", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..669556c --- /dev/null +++ b/flake.nix @@ -0,0 +1,114 @@ +{ + description = "gbp-rs"; + inputs = { + # wgsl_analyzer.url = "github:wgsl-analyzer/wgsl-analyzer"; + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs = { + nixpkgs.follows = "nixpkgs"; + flake-utils.follows = "flake-utils"; + }; + }; + }; + + outputs = { + self, + nixpkgs, + rust-overlay, + flake-utils, + ... + } @ inputs: + inputs.flake-utils.lib.eachDefaultSystem (system: let + overlays = [(import rust-overlay)]; + pkgs = import inputs.nixpkgs {inherit system overlays;}; + rust-extensions = [ + "rust-src" + "rust-analyzer" + "llvm-tools-preview" # used with `cargo-pgo` + ]; + rust-targets = ["wasm32-unknown-unknown"]; + bevy-deps = with pkgs; [ + udev + alsa-lib + vulkan-loader + xorg.libX11 + xorg.libXcursor + xorg.libXi + xorg.libXrandr + libxkbcommon + wayland + egl-wayland + # wgsl-analyzer-pkgs.wgsl_analyzer + # wgsl_analyzer.packages.${system} + # wgsl_analyzer.outputs.packages.${system}.default + ]; + # wgsl-analyzer-pkgs = import inputs.wgsl_analyzer {inherit system;}; + cargo-subcommands = with pkgs; [ + cargo-bloat + cargo-expand + cargo-outdated + cargo-show-asm + cargo-make + cargo-modules + cargo-nextest + cargo-rr + cargo-udeps + cargo-watch + cargo-wizard + cargo-pgo + # cargo-tree + + # # cargo-profiler + # # cargo-feature + ]; + rust-deps = with pkgs; + [ + # rustup + taplo # TOML formatter and LSP + bacon + mold # A Modern Linker + clang # For linking + gdb # debugger + # lldb # debugger + rr # time-traveling debugger + ] + ++ cargo-subcommands; + in + with pkgs; { + formatter.${system} = pkgs.alejandra; + devShells.default = pkgs.mkShell rec { + nativeBuildInputs = with pkgs; [ + pkgs.pkg-config + ]; + buildInputs = + [ + # (rust-bin.stable.latest.default.override + # { + # extensions = rust-extensions; + # targets = rust-targets; + # }) + # (rust-bin.beta.latest.default.override { + # extensions = ["rust-src" "rust-analyzer"]; + # }) + ( + rust-bin.selectLatestNightlyWith (toolchain: + toolchain.default.override { + extensions = + rust-extensions + ++ [ + "rustc-codegen-cranelift-preview" + ]; + targets = ["wasm32-unknown-unknown"]; + }) + ) + just + gh + ] + ++ rust-deps ++ bevy-deps; + + LD_LIBRARY_PATH = lib.makeLibraryPath buildInputs; + }; + }); +} diff --git a/src/lib.rs b/src/lib.rs index 9c8f963..e218074 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,249 +17,5 @@ #![doc = include_str!("../README.md")] #![warn(missing_docs)] -use kdtree::distance::squared_euclidean; -use num_traits::float::Float; -use num_traits::identities::Zero; -use rand::distributions::{Distribution, Uniform}; -use std::fmt::Debug; -use std::mem; -use tracing::debug; - -#[derive(Debug)] -enum ExtendStatus { - Reached(usize), - Advanced(usize), - Trapped, -} - -/// Node that contains user data -#[derive(Debug, Clone)] -struct Node { - parent_index: Option, - data: T, -} - -impl Node { - fn new(data: T) -> Self { - Node { - parent_index: None, - data, - } - } -} - -/// RRT -#[derive(Debug)] -struct Tree -where - N: Float + Zero + Debug, -{ - kdtree: kdtree::KdTree>, - vertices: Vec>>, - name: &'static str, -} - -impl Tree -where - N: Float + Zero + Debug, -{ - fn new(name: &'static str, dim: usize) -> Self { - Tree { - kdtree: kdtree::KdTree::new(dim), - vertices: Vec::new(), - name, - } - } - fn add_vertex(&mut self, q: &[N]) -> usize { - let index = self.vertices.len(); - self.kdtree.add(q.to_vec(), index).unwrap(); - self.vertices.push(Node::new(q.to_vec())); - index - } - fn add_edge(&mut self, q1_index: usize, q2_index: usize) { - self.vertices[q2_index].parent_index = Some(q1_index); - } - fn get_nearest_index(&self, q: &[N]) -> usize { - *self.kdtree.nearest(q, 1, &squared_euclidean).unwrap()[0].1 - } - fn extend(&mut self, q_target: &[N], extend_length: N, is_free: &mut FF) -> ExtendStatus - where - FF: FnMut(&[N]) -> bool, - { - assert!(extend_length > N::zero()); - let nearest_index = self.get_nearest_index(q_target); - let nearest_q = &self.vertices[nearest_index].data; - let diff_dist = squared_euclidean(q_target, nearest_q).sqrt(); - let q_new = if diff_dist < extend_length { - q_target.to_vec() - } else { - nearest_q - .iter() - .zip(q_target) - .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist) - .collect::>() - }; - debug!("q_new={q_new:?}"); - if is_free(&q_new) { - let new_index = self.add_vertex(&q_new); - self.add_edge(nearest_index, new_index); - if squared_euclidean(&q_new, q_target).sqrt() < extend_length { - return ExtendStatus::Reached(new_index); - } - debug!("target = {q_target:?}"); - debug!("advanced to {q_target:?}"); - return ExtendStatus::Advanced(new_index); - } - ExtendStatus::Trapped - } - fn connect(&mut self, q_target: &[N], extend_length: N, is_free: &mut FF) -> ExtendStatus - where - FF: FnMut(&[N]) -> bool, - { - loop { - debug!("connecting...{q_target:?}"); - match self.extend(q_target, extend_length, is_free) { - ExtendStatus::Trapped => return ExtendStatus::Trapped, - ExtendStatus::Reached(index) => return ExtendStatus::Reached(index), - ExtendStatus::Advanced(_) => {} - }; - } - } - fn get_until_root(&self, index: usize) -> Vec> { - let mut nodes = Vec::new(); - let mut cur_index = index; - while let Some(parent_index) = self.vertices[cur_index].parent_index { - cur_index = parent_index; - nodes.push(self.vertices[cur_index].data.clone()) - } - nodes - } -} - -/// search the path from start to goal which is free, using random_sample function -pub fn dual_rrt_connect( - start: &[N], - goal: &[N], - mut is_free: FF, - random_sample: FR, - extend_length: N, - num_max_try: usize, -) -> Result>, String> -where - FF: FnMut(&[N]) -> bool, - FR: Fn() -> Vec, - N: Float + Debug, -{ - assert_eq!(start.len(), goal.len()); - let mut tree_a = Tree::new("start", start.len()); - let mut tree_b = Tree::new("goal", start.len()); - tree_a.add_vertex(start); - tree_b.add_vertex(goal); - for _ in 0..num_max_try { - debug!("tree_a = {:?}", tree_a.vertices.len()); - debug!("tree_b = {:?}", tree_b.vertices.len()); - let q_rand = random_sample(); - let extend_status = tree_a.extend(&q_rand, extend_length, &mut is_free); - match extend_status { - ExtendStatus::Trapped => {} - ExtendStatus::Advanced(new_index) | ExtendStatus::Reached(new_index) => { - let q_new = &tree_a.vertices[new_index].data; - if let ExtendStatus::Reached(reach_index) = - tree_b.connect(q_new, extend_length, &mut is_free) - { - let mut a_all = tree_a.get_until_root(new_index); - let mut b_all = tree_b.get_until_root(reach_index); - a_all.reverse(); - a_all.append(&mut b_all); - if tree_b.name == "start" { - a_all.reverse(); - } - return Ok(a_all); - } - } - } - mem::swap(&mut tree_a, &mut tree_b); - } - Err("failed".to_string()) -} - -/// select random two points, and try to connect. -pub fn smooth_path( - path: &mut Vec>, - mut is_free: FF, - extend_length: N, - num_max_try: usize, -) where - FF: FnMut(&[N]) -> bool, - N: Float + Debug, -{ - if path.len() < 3 { - return; - } - let mut rng = rand::thread_rng(); - for _ in 0..num_max_try { - let range1 = Uniform::new(0, path.len() - 2); - let ind1 = range1.sample(&mut rng); - let range2 = Uniform::new(ind1 + 2, path.len()); - let ind2 = range2.sample(&mut rng); - let mut base_point = path[ind1].clone(); - let point2 = path[ind2].clone(); - let mut is_searching = true; - while is_searching { - let diff_dist = squared_euclidean(&base_point, &point2).sqrt(); - if diff_dist < extend_length { - // reached! - // remove path[ind1+1] ... path[ind2-1] - let remove_index = ind1 + 1; - for _ in 0..(ind2 - ind1 - 1) { - path.remove(remove_index); - } - if path.len() == 2 { - return; - } - is_searching = false; - } else { - let check_point = base_point - .iter() - .zip(point2.iter()) - .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist) - .collect::>(); - if !is_free(&check_point) { - // trapped - is_searching = false; - } else { - // continue to extend - base_point = check_point; - } - } - } - } -} - -#[test] -fn it_works() { - use rand::distributions::{Distribution, Uniform}; - let mut result = dual_rrt_connect( - &[-1.2, 0.0], - &[1.2, 0.0], - |p: &[f64]| !(p[0].abs() < 1.0 && p[1].abs() < 1.0), - || { - let between = Uniform::new(-2.0, 2.0); - let mut rng = rand::thread_rng(); - vec![between.sample(&mut rng), between.sample(&mut rng)] - }, - 0.2, - 1000, - ) - .unwrap(); - println!("{result:?}"); - assert!(result.len() >= 4); - smooth_path( - &mut result, - |p: &[f64]| !(p[0].abs() < 1.0 && p[1].abs() < 1.0), - 0.2, - 100, - ); - println!("{result:?}"); - assert!(result.len() >= 3); -} +pub mod rrt; +pub mod rrtstar; diff --git a/src/rrt.rs b/src/rrt.rs new file mode 100644 index 0000000..9c8f963 --- /dev/null +++ b/src/rrt.rs @@ -0,0 +1,265 @@ +/* + Copyright 2017 Takashi Ogura + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#![doc = include_str!("../README.md")] +#![warn(missing_docs)] + +use kdtree::distance::squared_euclidean; +use num_traits::float::Float; +use num_traits::identities::Zero; +use rand::distributions::{Distribution, Uniform}; +use std::fmt::Debug; +use std::mem; +use tracing::debug; + +#[derive(Debug)] +enum ExtendStatus { + Reached(usize), + Advanced(usize), + Trapped, +} + +/// Node that contains user data +#[derive(Debug, Clone)] +struct Node { + parent_index: Option, + data: T, +} + +impl Node { + fn new(data: T) -> Self { + Node { + parent_index: None, + data, + } + } +} + +/// RRT +#[derive(Debug)] +struct Tree +where + N: Float + Zero + Debug, +{ + kdtree: kdtree::KdTree>, + vertices: Vec>>, + name: &'static str, +} + +impl Tree +where + N: Float + Zero + Debug, +{ + fn new(name: &'static str, dim: usize) -> Self { + Tree { + kdtree: kdtree::KdTree::new(dim), + vertices: Vec::new(), + name, + } + } + fn add_vertex(&mut self, q: &[N]) -> usize { + let index = self.vertices.len(); + self.kdtree.add(q.to_vec(), index).unwrap(); + self.vertices.push(Node::new(q.to_vec())); + index + } + fn add_edge(&mut self, q1_index: usize, q2_index: usize) { + self.vertices[q2_index].parent_index = Some(q1_index); + } + fn get_nearest_index(&self, q: &[N]) -> usize { + *self.kdtree.nearest(q, 1, &squared_euclidean).unwrap()[0].1 + } + fn extend(&mut self, q_target: &[N], extend_length: N, is_free: &mut FF) -> ExtendStatus + where + FF: FnMut(&[N]) -> bool, + { + assert!(extend_length > N::zero()); + let nearest_index = self.get_nearest_index(q_target); + let nearest_q = &self.vertices[nearest_index].data; + let diff_dist = squared_euclidean(q_target, nearest_q).sqrt(); + let q_new = if diff_dist < extend_length { + q_target.to_vec() + } else { + nearest_q + .iter() + .zip(q_target) + .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist) + .collect::>() + }; + debug!("q_new={q_new:?}"); + if is_free(&q_new) { + let new_index = self.add_vertex(&q_new); + self.add_edge(nearest_index, new_index); + if squared_euclidean(&q_new, q_target).sqrt() < extend_length { + return ExtendStatus::Reached(new_index); + } + debug!("target = {q_target:?}"); + debug!("advanced to {q_target:?}"); + return ExtendStatus::Advanced(new_index); + } + ExtendStatus::Trapped + } + fn connect(&mut self, q_target: &[N], extend_length: N, is_free: &mut FF) -> ExtendStatus + where + FF: FnMut(&[N]) -> bool, + { + loop { + debug!("connecting...{q_target:?}"); + match self.extend(q_target, extend_length, is_free) { + ExtendStatus::Trapped => return ExtendStatus::Trapped, + ExtendStatus::Reached(index) => return ExtendStatus::Reached(index), + ExtendStatus::Advanced(_) => {} + }; + } + } + fn get_until_root(&self, index: usize) -> Vec> { + let mut nodes = Vec::new(); + let mut cur_index = index; + while let Some(parent_index) = self.vertices[cur_index].parent_index { + cur_index = parent_index; + nodes.push(self.vertices[cur_index].data.clone()) + } + nodes + } +} + +/// search the path from start to goal which is free, using random_sample function +pub fn dual_rrt_connect( + start: &[N], + goal: &[N], + mut is_free: FF, + random_sample: FR, + extend_length: N, + num_max_try: usize, +) -> Result>, String> +where + FF: FnMut(&[N]) -> bool, + FR: Fn() -> Vec, + N: Float + Debug, +{ + assert_eq!(start.len(), goal.len()); + let mut tree_a = Tree::new("start", start.len()); + let mut tree_b = Tree::new("goal", start.len()); + tree_a.add_vertex(start); + tree_b.add_vertex(goal); + for _ in 0..num_max_try { + debug!("tree_a = {:?}", tree_a.vertices.len()); + debug!("tree_b = {:?}", tree_b.vertices.len()); + let q_rand = random_sample(); + let extend_status = tree_a.extend(&q_rand, extend_length, &mut is_free); + match extend_status { + ExtendStatus::Trapped => {} + ExtendStatus::Advanced(new_index) | ExtendStatus::Reached(new_index) => { + let q_new = &tree_a.vertices[new_index].data; + if let ExtendStatus::Reached(reach_index) = + tree_b.connect(q_new, extend_length, &mut is_free) + { + let mut a_all = tree_a.get_until_root(new_index); + let mut b_all = tree_b.get_until_root(reach_index); + a_all.reverse(); + a_all.append(&mut b_all); + if tree_b.name == "start" { + a_all.reverse(); + } + return Ok(a_all); + } + } + } + mem::swap(&mut tree_a, &mut tree_b); + } + Err("failed".to_string()) +} + +/// select random two points, and try to connect. +pub fn smooth_path( + path: &mut Vec>, + mut is_free: FF, + extend_length: N, + num_max_try: usize, +) where + FF: FnMut(&[N]) -> bool, + N: Float + Debug, +{ + if path.len() < 3 { + return; + } + let mut rng = rand::thread_rng(); + for _ in 0..num_max_try { + let range1 = Uniform::new(0, path.len() - 2); + let ind1 = range1.sample(&mut rng); + let range2 = Uniform::new(ind1 + 2, path.len()); + let ind2 = range2.sample(&mut rng); + let mut base_point = path[ind1].clone(); + let point2 = path[ind2].clone(); + let mut is_searching = true; + while is_searching { + let diff_dist = squared_euclidean(&base_point, &point2).sqrt(); + if diff_dist < extend_length { + // reached! + // remove path[ind1+1] ... path[ind2-1] + let remove_index = ind1 + 1; + for _ in 0..(ind2 - ind1 - 1) { + path.remove(remove_index); + } + if path.len() == 2 { + return; + } + is_searching = false; + } else { + let check_point = base_point + .iter() + .zip(point2.iter()) + .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist) + .collect::>(); + if !is_free(&check_point) { + // trapped + is_searching = false; + } else { + // continue to extend + base_point = check_point; + } + } + } + } +} + +#[test] +fn it_works() { + use rand::distributions::{Distribution, Uniform}; + let mut result = dual_rrt_connect( + &[-1.2, 0.0], + &[1.2, 0.0], + |p: &[f64]| !(p[0].abs() < 1.0 && p[1].abs() < 1.0), + || { + let between = Uniform::new(-2.0, 2.0); + let mut rng = rand::thread_rng(); + vec![between.sample(&mut rng), between.sample(&mut rng)] + }, + 0.2, + 1000, + ) + .unwrap(); + println!("{result:?}"); + assert!(result.len() >= 4); + smooth_path( + &mut result, + |p: &[f64]| !(p[0].abs() < 1.0 && p[1].abs() < 1.0), + 0.2, + 100, + ); + println!("{result:?}"); + assert!(result.len() >= 3); +} diff --git a/src/rrtstar.rs b/src/rrtstar.rs new file mode 100644 index 0000000..f2b9c6e --- /dev/null +++ b/src/rrtstar.rs @@ -0,0 +1,365 @@ +/* + Copyright 2017 Takashi Ogura + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#![doc = include_str!("../README.md")] +#![warn(missing_docs)] + +use kdtree::distance::squared_euclidean; +use num_traits::float::Float; +use num_traits::identities::Zero; +use rand::{ + distributions::{Distribution, Uniform}, + RngCore, +}; +use std::fmt::Debug; + +// #[derive(Debug)] +// enum ExtendStatus { +// Reached(usize), +// Advanced(usize), +// Trapped, +// } + +/// Trait to express a weight/cost for a node in the tree +pub trait Weight: Float + Zero {} + +impl Weight for f64 {} +impl Weight for f32 {} + +/// Node that contains user data +#[derive(Debug, Clone)] +pub struct Node { + pub parent_index: Option, + pub data: T, + pub weight: W, +} + +impl Node { + fn new(data: T, weight: W) -> Self { + Node { + parent_index: None, + data, + weight, + } + } +} + +/// RRT +#[derive(Debug)] +pub struct Tree +where + N: Float + Zero + Debug, + W: Weight, +{ + /// kdtree data structure to store the nodes + /// for fast nearest neighbour search + pub kdtree: kdtree::KdTree>, + /// Vertices of the tree + pub vertices: Vec, W>>, + /// The goal index + pub goal_index: Option, +} + +// impl default for Tree +impl Default for Tree +where + N: Float + Zero + Debug, + W: Weight, +{ + fn default() -> Self { + Tree { + kdtree: kdtree::KdTree::new(2), + vertices: Vec::new(), + goal_index: None, + } + } +} + +impl Tree +where + N: Float + Zero + Debug, + W: Weight, +{ + fn new(dim: usize) -> Self { + Tree { + kdtree: kdtree::KdTree::new(dim), + vertices: Vec::new(), + goal_index: None, + } + } + + // Add a vertex to the tree + fn add_vertex(&mut self, q: &[N], weight: W) -> usize { + let index = self.vertices.len(); + self.kdtree.add(q.to_vec(), index).unwrap(); + self.vertices.push(Node::new(q.to_vec(), weight)); + index + } + + // + fn add_edge(&mut self, q1_index: usize, q2_index: usize) { + self.vertices[q2_index].parent_index = Some(q1_index); + } + + fn remove_edge(&mut self, q_index: usize) { + self.vertices[q_index].parent_index = None; + } + + // + fn get_nearest_index(&self, q: &[N]) -> usize { + *self.kdtree.nearest(q, 1, &squared_euclidean).unwrap()[0].1 + } + + /// Get the path from the root to the node + pub fn get_until_root(&self, index: usize) -> Vec> { + let mut nodes = Vec::new(); + let mut cur_index = index; + while let Some(parent_index) = self.vertices[cur_index].parent_index { + cur_index = parent_index; + nodes.push(self.vertices[cur_index].data.clone()) + } + nodes + } + + // Get indices of nerest nodes within a radius + fn get_nearest_neighbours(&self, q_new: &[N], extend_length: N) -> Vec { + self.kdtree + .within(q_new, extend_length.powi(2), &squared_euclidean) + .unwrap_or(vec![]) + .iter() + .map(|(_, index)| **index) + .collect() + } +} + +/// RRT* error +#[derive(Debug, derive_more::Error, derive_more::Display)] +pub enum RRTStarError { + /// Failed to find a path within the maximum number of iterations + #[display(fmt = "Failed to find a path within the maximum number of iterations")] + MaxItersReached, +} + +// pub type RRTStarResult = Result>, RRTStarError>; +/// This is the return type for rrtstar +pub type RRTStarResult = Result, RRTStarError>; + +/// search the path from start to goal which is free, using random_sample function +/// https://erc-bpgc.github.io/handbook/automation/PathPlanners/Sampling_Based_Algorithms/RRT_Star/ +pub fn rrtstar( + start: &[N], + goal: &[N], + mut is_collision_free: impl FnMut(&[N]) -> bool, + mut random_sample: impl FnMut() -> Vec, + extend_length: N, + max_iters: usize, + neighbourhood_radius: N, + stop_when_reach_goal: bool, +) -> RRTStarResult +// ) -> Result>, RRTStarError> +where + // FF: FnMut(&[N]) -> bool, + // FR: Fn() -> Vec, + N: Float + Debug, + // W: Weight, +{ + assert_eq!(start.len(), goal.len()); + let mut tree = Tree::::new(start.len()); + tree.add_vertex(start, 0.0); + + let mut goal_reached = false; + + // Path finding loop + for _ in 0..max_iters { + // 1. Random sample + let q_rand = random_sample(); + // 2. Nearest neighbour + let nearest_index = tree.get_nearest_index(&q_rand); + let q_nearest = &tree.vertices[nearest_index].data; + // 3. Steer to get new point + let diff_dist = squared_euclidean(q_rand.as_slice(), q_nearest.as_slice()).sqrt(); + let q_new = if diff_dist < extend_length { + q_rand.to_vec() + } else { + q_nearest + .iter() + .zip(q_rand) + .map(|(near, target)| *near + (target - *near) * extend_length / diff_dist) + .collect::>() + }; + + // 4. Check if the new point is free + if !is_collision_free(&q_new) { + continue; + } + + // 5. Connect to the new point + // 5.1. Find nearest neighbours + let nearest = tree.get_nearest_neighbours(&q_new, neighbourhood_radius); + // 5.2. Insert the new point to the tree + let parent_weight = tree.vertices[nearest_index].weight; + let edge_weight = ::from::(extend_length) + .expect("N implements Float, same as W"); + let cost_min = parent_weight + edge_weight; + + let new_index = tree.add_vertex(&q_new, cost_min); + // 5.3. Connect to lowest cost path + let min_index = std::iter::once(&nearest_index) + .chain(nearest.iter()) + .min_by(|&a, &b| { + let a_potential_weight = tree.vertices[*a].weight + + ::from( + squared_euclidean(&q_new, &tree.vertices[*a].data).sqrt(), + ) + .expect("N implements Float, same as W"); + + let b_potential_weight = tree.vertices[*b].weight + + ::from( + squared_euclidean(&q_new, &tree.vertices[*b].data).sqrt(), + ) + .expect("N implements Float, same as W"); + + a_potential_weight + .partial_cmp(&b_potential_weight) + .expect("Weight W of two nodes should be comparable") + }) + .expect("iterator shouldn't be empty"); + + tree.add_edge(*min_index, new_index); + + // 5.4. Rewire + for &near_index in nearest.iter() { + let near_weight = tree.vertices[near_index].weight; + let new_potential_cost = cost_min + + ::from( + squared_euclidean(&q_new, &tree.vertices[near_index].data).sqrt(), + ) + .expect("N implements Float, same as W"); + + if new_potential_cost < near_weight { + tree.remove_edge(near_index); + tree.add_edge(new_index, near_index); + tree.vertices[near_index].weight = new_potential_cost; + } + } + + // 6. Check if the goal is reached + if !goal_reached && squared_euclidean(&q_new, goal).sqrt() < extend_length { + let goal_weight = tree.vertices[new_index].weight + + ::from(squared_euclidean(&q_new, goal).sqrt()) + .expect("N implements Float, same as W"); + // println!("goal {:?} reached with weight {}", goal, goal_weight); + let goal_index = tree.add_vertex(goal, goal_weight); + tree.add_edge(new_index, goal_index); + + tree.goal_index = Some(goal_index); + + goal_reached = true; + + if stop_when_reach_goal { + return Ok(tree); + } + } + } + + if !stop_when_reach_goal { + return Ok(tree); + } else { + Err(RRTStarError::MaxItersReached) + } +} + +/// select random two points, and try to connect. +pub fn smooth_path( + path: &mut Vec>, + mut is_free: FF, + extend_length: N, + num_max_try: usize, + mut rng: &mut dyn RngCore, +) where + FF: FnMut(&[N]) -> bool, + N: Float + Debug, +{ + if path.len() < 3 { + return; + } + // let mut rng = rand::thread_rng(); + for _ in 0..num_max_try { + let range1 = Uniform::new(0, path.len() - 2); + let ind1 = range1.sample(&mut rng); + let range2 = Uniform::new(ind1 + 2, path.len()); + let ind2 = range2.sample(&mut rng); + let mut base_point = path[ind1].clone(); + let point2 = path[ind2].clone(); + let mut is_searching = true; + while is_searching { + let diff_dist = squared_euclidean(&base_point, &point2).sqrt(); + if diff_dist < extend_length { + // reached! + // remove path[ind1+1] ... path[ind2-1] + let remove_index = ind1 + 1; + for _ in 0..(ind2 - ind1 - 1) { + path.remove(remove_index); + } + if path.len() == 2 { + return; + } + is_searching = false; + } else { + let check_point = base_point + .iter() + .zip(point2.iter()) + .map(|(near, target)| *near + (*target - *near) * extend_length / diff_dist) + .collect::>(); + if !is_free(&check_point) { + // trapped + is_searching = false; + } else { + // continue to extend + base_point = check_point; + } + } + } + } +} + +#[test] +fn it_works() { + use rand::distributions::{Distribution, Uniform}; + let mut result = rrtstar( + &[-1.2, 0.0], + &[1.2, 0.0], + |p: &[f64]| !(p[0].abs() < 1.0 && p[1].abs() < 1.0), + || { + let between = Uniform::new(-2.0, 2.0); + let mut rng = rand::thread_rng(); + vec![between.sample(&mut rng), between.sample(&mut rng)] + }, + 0.2, + 1000, + ) + .unwrap(); + println!("{result:?}"); + // assert!(result.len() >= 4); + // smooth_path( + // &mut result, + // |p: &[f64]| !(p[0].abs() < 1.0 && p[1].abs() < 1.0), + // 0.2, + // 100, + // ); + // println!("{result:?}"); + // assert!(result.len() >= 3); +}