Skip to content

Commit ca65261

Browse files
Merge branch 'main' into dialect-derives
2 parents 149effc + 2ac82e9 commit ca65261

File tree

8 files changed

+832
-260
lines changed

8 files changed

+832
-260
lines changed

Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ std = []
4242
recursive-protection = ["std", "recursive"]
4343
# Enable JSON output in the `cli` example:
4444
json_example = ["serde_json", "serde"]
45+
derive-dialect = ["sqlparser_derive"]
4546
visitor = ["sqlparser_derive"]
4647

4748
[dependencies]
@@ -61,6 +62,10 @@ simple_logger = "5.0"
6162
matches = "0.1"
6263
pretty_assertions = "1"
6364

65+
[[test]]
66+
name = "sqlparser_derive_dialect"
67+
required-features = ["derive-dialect"]
68+
6469
[package.metadata.docs.rs]
6570
# Document these features on docs.rs
66-
features = ["serde", "visitor"]
71+
features = ["serde", "visitor", "derive-dialect"]

derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ edition = "2021"
3636
proc-macro = true
3737

3838
[dependencies]
39-
syn = { version = "2.0", default-features = false, features = ["printing", "parsing", "derive", "proc-macro"] }
39+
syn = { version = "2.0", default-features = false, features = ["full", "printing", "parsing", "derive", "proc-macro", "clone-impls"] }
4040
proc-macro2 = "1.0"
4141
quote = "1.0"

derive/src/dialect.rs

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Implementation of the `derive_dialect!` macro for creating custom SQL dialects.
19+
20+
use proc_macro2::TokenStream;
21+
use quote::{quote, quote_spanned};
22+
use std::collections::HashSet;
23+
use syn::{
24+
braced,
25+
parse::{Parse, ParseStream},
26+
Error, File, FnArg, Ident, Item, LitBool, LitChar, Pat, ReturnType, Signature, Token,
27+
TraitItem, Type,
28+
};
29+
30+
/// Override value types supported by the macro
31+
pub(crate) enum Override {
32+
Bool(LitBool),
33+
Char(LitChar),
34+
None,
35+
}
36+
37+
/// Parsed input for the `derive_dialect!` macro
38+
pub(crate) struct DeriveDialectInput {
39+
pub name: Ident,
40+
pub base: Type,
41+
pub preserve_type_id: bool,
42+
pub overrides: Vec<(Ident, Override)>,
43+
}
44+
45+
/// `Dialect` trait method attrs
46+
struct DialectMethod {
47+
name: Ident,
48+
signature: Signature,
49+
}
50+
51+
impl Parse for DeriveDialectInput {
52+
fn parse(input: ParseStream) -> syn::Result<Self> {
53+
let name: Ident = input.parse()?;
54+
input.parse::<Token![,]>()?;
55+
let base: Type = input.parse()?;
56+
57+
let mut preserve_type_id = false;
58+
let mut overrides = Vec::new();
59+
60+
while input.peek(Token![,]) {
61+
input.parse::<Token![,]>()?;
62+
if input.is_empty() {
63+
break;
64+
}
65+
if input.peek(Ident) {
66+
let ident: Ident = input.parse()?;
67+
match ident.to_string().as_str() {
68+
"preserve_type_id" => {
69+
input.parse::<Token![=]>()?;
70+
preserve_type_id = input.parse::<LitBool>()?.value();
71+
}
72+
"overrides" => {
73+
input.parse::<Token![=]>()?;
74+
let content;
75+
braced!(content in input);
76+
while !content.is_empty() {
77+
let key: Ident = content.parse()?;
78+
content.parse::<Token![=]>()?;
79+
let value = if content.peek(LitBool) {
80+
Override::Bool(content.parse()?)
81+
} else if content.peek(LitChar) {
82+
Override::Char(content.parse()?)
83+
} else if content.peek(Ident) {
84+
let ident: Ident = content.parse()?;
85+
if ident == "None" {
86+
Override::None
87+
} else {
88+
return Err(Error::new(
89+
ident.span(),
90+
format!("Expected `true`, `false`, a char, or `None`, found `{ident}`"),
91+
));
92+
}
93+
} else {
94+
return Err(
95+
content.error("Expected `true`, `false`, a char, or `None`")
96+
);
97+
};
98+
overrides.push((key, value));
99+
if content.peek(Token![,]) {
100+
content.parse::<Token![,]>()?;
101+
}
102+
}
103+
}
104+
other => {
105+
return Err(Error::new(ident.span(), format!(
106+
"Unknown argument `{other}`. Expected `preserve_type_id` or `overrides`."
107+
)));
108+
}
109+
}
110+
}
111+
}
112+
Ok(DeriveDialectInput {
113+
name,
114+
base,
115+
preserve_type_id,
116+
overrides,
117+
})
118+
}
119+
}
120+
121+
/// Entry point for the `derive_dialect!` macro
122+
pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream {
123+
let err = |msg: String| {
124+
Error::new(proc_macro2::Span::call_site(), msg)
125+
.to_compile_error()
126+
.into()
127+
};
128+
129+
let source = match read_dialect_mod_file() {
130+
Ok(s) => s,
131+
Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")),
132+
};
133+
let file: File = match syn::parse_str(&source) {
134+
Ok(f) => f,
135+
Err(e) => return err(format!("Failed to parse source: {e}")),
136+
};
137+
let methods = match extract_dialect_methods(&file) {
138+
Ok(m) => m,
139+
Err(e) => return e.to_compile_error().into(),
140+
};
141+
142+
// Validate overrides
143+
let bool_names: HashSet<_> = methods
144+
.iter()
145+
.filter(|m| is_bool_method(&m.signature))
146+
.map(|m| m.name.to_string())
147+
.collect();
148+
for (key, value) in &input.overrides {
149+
let key_str = key.to_string();
150+
let err = |msg| Error::new(key.span(), msg).to_compile_error().into();
151+
match value {
152+
Override::Bool(_) if !bool_names.contains(&key_str) => {
153+
return err(format!("Unknown boolean method `{key_str}`"));
154+
}
155+
Override::Char(_) | Override::None if key_str != "identifier_quote_style" => {
156+
return err(format!(
157+
"Char/None only valid for `identifier_quote_style`, not `{key_str}`"
158+
));
159+
}
160+
_ => {}
161+
}
162+
}
163+
generate_derived_dialect(&input, &methods).into()
164+
}
165+
166+
/// Generate the complete derived `Dialect` implementation
167+
fn generate_derived_dialect(input: &DeriveDialectInput, methods: &[DialectMethod]) -> TokenStream {
168+
let name = &input.name;
169+
let base = &input.base;
170+
171+
// Helper to find an override by method name
172+
let find_override = |method_name: &str| {
173+
input
174+
.overrides
175+
.iter()
176+
.find(|(k, _)| k == method_name)
177+
.map(|(_, v)| v)
178+
};
179+
180+
// Helper to generate delegation to base dialect
181+
let delegate = |method: &DialectMethod| {
182+
let sig = &method.signature;
183+
let method_name = &method.name;
184+
let params = extract_param_names(sig);
185+
quote_spanned! { method_name.span() => #sig { self.dialect.#method_name(#(#params),*) } }
186+
};
187+
188+
// Generate the struct
189+
let struct_def = quote_spanned! { name.span() =>
190+
#[derive(Debug, Default)]
191+
pub struct #name {
192+
dialect: #base,
193+
}
194+
impl #name {
195+
pub fn new() -> Self { Self::default() }
196+
}
197+
};
198+
199+
// Generate TypeId method body
200+
let type_id_body = if input.preserve_type_id {
201+
quote! { Dialect::dialect(&self.dialect) }
202+
} else {
203+
quote! { ::core::any::TypeId::of::<#name>() }
204+
};
205+
206+
// Generate method implementations
207+
let method_impls = methods.iter().map(|method| {
208+
let method_name = &method.name;
209+
match find_override(&method_name.to_string()) {
210+
Some(Override::Bool(value)) => {
211+
quote_spanned! { method_name.span() => fn #method_name(&self) -> bool { #value } }
212+
}
213+
Some(Override::Char(c)) => {
214+
quote_spanned! { method_name.span() =>
215+
fn identifier_quote_style(&self, _: &str) -> Option<char> { Some(#c) }
216+
}
217+
}
218+
Some(Override::None) => {
219+
quote_spanned! { method_name.span() =>
220+
fn identifier_quote_style(&self, _: &str) -> Option<char> { None }
221+
}
222+
}
223+
None => delegate(method),
224+
}
225+
});
226+
227+
// Wrap impl in a const block with scoped imports so types resolve without qualification
228+
quote! {
229+
#struct_def
230+
const _: () = {
231+
use ::core::iter::Peekable;
232+
use ::core::str::Chars;
233+
use sqlparser::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement};
234+
use sqlparser::dialect::{Dialect, Precedence};
235+
use sqlparser::keywords::Keyword;
236+
use sqlparser::parser::{Parser, ParserError};
237+
238+
impl Dialect for #name {
239+
fn dialect(&self) -> ::core::any::TypeId { #type_id_body }
240+
#(#method_impls)*
241+
}
242+
};
243+
}
244+
}
245+
246+
/// Extract parameter names from a method signature (excluding self)
247+
fn extract_param_names(sig: &Signature) -> Vec<&Ident> {
248+
sig.inputs
249+
.iter()
250+
.filter_map(|arg| match arg {
251+
FnArg::Typed(pt) => match pt.pat.as_ref() {
252+
Pat::Ident(pi) => Some(&pi.ident),
253+
_ => None,
254+
},
255+
_ => None,
256+
})
257+
.collect()
258+
}
259+
260+
/// Read the `dialect/mod.rs` file that contains the Dialect trait.
261+
fn read_dialect_mod_file() -> Result<String, String> {
262+
let manifest_dir =
263+
std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set")?;
264+
let path = std::path::Path::new(&manifest_dir).join("src/dialect/mod.rs");
265+
std::fs::read_to_string(&path).map_err(|e| format!("Failed to read {}: {e}", path.display()))
266+
}
267+
268+
/// Extract all methods from the `Dialect` trait (excluding `dialect` for TypeId)
269+
fn extract_dialect_methods(file: &File) -> Result<Vec<DialectMethod>, Error> {
270+
let dialect_trait = file
271+
.items
272+
.iter()
273+
.find_map(|item| match item {
274+
Item::Trait(t) if t.ident == "Dialect" => Some(t),
275+
_ => None,
276+
})
277+
.ok_or_else(|| Error::new(proc_macro2::Span::call_site(), "Dialect trait not found"))?;
278+
279+
let mut methods: Vec<_> = dialect_trait
280+
.items
281+
.iter()
282+
.filter_map(|item| match item {
283+
TraitItem::Fn(m) if m.sig.ident != "dialect" => Some(DialectMethod {
284+
name: m.sig.ident.clone(),
285+
signature: m.sig.clone(),
286+
}),
287+
_ => None,
288+
})
289+
.collect();
290+
methods.sort_by_key(|m| m.name.to_string());
291+
Ok(methods)
292+
}
293+
294+
/// Check if a method signature is `fn name(&self) -> bool`
295+
fn is_bool_method(sig: &Signature) -> bool {
296+
sig.inputs.len() == 1
297+
&& matches!(
298+
sig.inputs.first(),
299+
Some(FnArg::Receiver(r)) if r.reference.is_some() && r.mutability.is_none()
300+
)
301+
&& matches!(
302+
&sig.output,
303+
ReturnType::Type(_, ty) if matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("bool"))
304+
)
305+
}

0 commit comments

Comments
 (0)