forked from datafusion-contrib/datafusion-distributed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlocalhost_run.rs
More file actions
116 lines (101 loc) · 3.49 KB
/
Copy pathlocalhost_run.rs
File metadata and controls
116 lines (101 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use arrow::util::pretty::pretty_format_batches;
use arrow_flight::flight_service_client::FlightServiceClient;
use async_trait::async_trait;
use dashmap::{DashMap, Entry};
use datafusion::common::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion::physical_plan::displayable;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use datafusion_distributed::{
BoxCloneSyncChannel, ChannelResolver, DistributedExt, DistributedPhysicalOptimizerRule,
};
use futures::TryStreamExt;
use std::error::Error;
use std::sync::Arc;
use structopt::StructOpt;
use tonic::transport::Channel;
use url::Url;
// Command-line arguments of the localhost runner, parsed in `main` via `Args::from_args()`.
#[derive(StructOpt)]
#[structopt(name = "run", about = "A localhost Distributed DataFusion runner")]
struct Args {
// SQL text to run; no `long`/`short` name is given, so it is a positional argument.
// It is handed to `ctx.sql` in `main`.
#[structopt()]
query: String,
// Comma-separated list of localhost ports, e.g. `--cluster-ports 8080,8081,8082`.
// Each port is turned into an `http://localhost:{port}` URL by `LocalhostChannelResolver`.
#[structopt(long = "cluster-ports", use_delimiter = true)]
cluster_ports: Vec<u16>,
// When set, `main` prints the physical plan instead of executing the query.
#[structopt(long)]
explain: bool,
// Forwarded to `DistributedPhysicalOptimizerRule::with_network_shuffle_tasks`; defaults to 3.
#[structopt(long, default_value = "3")]
network_shuffle_tasks: usize,
// Forwarded to `DistributedPhysicalOptimizerRule::with_network_coalesce_tasks`; defaults to 3.
#[structopt(long, default_value = "3")]
network_coalesce_tasks: usize,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let args = Args::from_args();
let localhost_resolver = LocalhostChannelResolver {
ports: args.cluster_ports,
cached: DashMap::new(),
};
let state = SessionStateBuilder::new()
.with_default_features()
.with_distributed_channel_resolver(localhost_resolver)
.with_physical_optimizer_rule(Arc::new(
DistributedPhysicalOptimizerRule::new()
.with_network_coalesce_tasks(args.network_coalesce_tasks)
.with_network_shuffle_tasks(args.network_shuffle_tasks),
))
.build();
let ctx = SessionContext::from(state);
ctx.register_parquet(
"flights_1m",
"testdata/flights-1m.parquet",
ParquetReadOptions::default(),
)
.await?;
ctx.register_parquet("weather", "testdata/weather", ParquetReadOptions::default())
.await?;
let df = ctx.sql(&args.query).await?;
if args.explain {
let plan = df.create_physical_plan().await?;
let display = displayable(plan.as_ref()).indent(true).to_string();
println!("{display}");
} else {
let stream = df.execute_stream().await?;
let batches = stream.try_collect::<Vec<_>>().await?;
let formatted = pretty_format_batches(&batches)?;
println!("{formatted}");
}
Ok(())
}
/// `ChannelResolver` implementation that maps a fixed list of ports to
/// `http://localhost:{port}` URLs and keeps one Flight client per URL.
// NOTE(review): `#[derive(Clone)]` clones the `DashMap` field as well, so clones of this
// resolver presumably do not share the client cache — confirm whether that is intended.
#[derive(Clone)]
struct LocalhostChannelResolver {
/// Ports used to build the URLs returned by `get_urls`.
ports: Vec<u16>,
/// Flight clients created so far, keyed by the URL they were created for.
cached: DashMap<Url, FlightServiceClient<BoxCloneSyncChannel>>,
}
#[async_trait]
impl ChannelResolver for LocalhostChannelResolver {
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
Ok(self
.ports
.iter()
.map(|port| Url::parse(&format!("http://localhost:{port}")).unwrap())
.collect())
}
async fn get_flight_client_for_url(
&self,
url: &Url,
) -> Result<FlightServiceClient<BoxCloneSyncChannel>, DataFusionError> {
match self.cached.entry(url.clone()) {
Entry::Occupied(v) => Ok(v.get().clone()),
Entry::Vacant(v) => {
let channel = Channel::from_shared(url.to_string())
.unwrap()
.connect_lazy();
let channel = FlightServiceClient::new(BoxCloneSyncChannel::new(channel));
v.insert(channel.clone());
Ok(channel)
}
}
}
}