From b6e24fb58b842f07397c6f1b454ff2fb16ee4a76 Mon Sep 17 00:00:00 2001 From: Ava Johnson Date: Mon, 15 Jun 2026 00:33:31 -0700 Subject: [PATCH 1/4] derive: add function-like macro for full file code generation --- graphql_client_cli/src/generate.rs | 17 +-- graphql_client_codegen/src/codegen_options.rs | 24 +++- .../src/generated_module.rs | 4 +- graphql_client_codegen/src/lib.rs | 2 +- graphql_query_derive/src/attributes.rs | 119 +++++++++--------- graphql_query_derive/src/lib.rs | 54 ++++++-- 6 files changed, 136 insertions(+), 84 deletions(-) diff --git a/graphql_client_cli/src/generate.rs b/graphql_client_cli/src/generate.rs index f6fe05fe..a466d2f2 100644 --- a/graphql_client_cli/src/generate.rs +++ b/graphql_client_cli/src/generate.rs @@ -8,7 +8,6 @@ use std::fs::File; use std::io::Write as _; use std::path::PathBuf; use std::process::Stdio; -use syn::{token::Paren, token::Pub, VisRestricted, Visibility}; pub(crate) struct CliCodegenParams { pub query_path: PathBuf, @@ -36,7 +35,7 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { deprecation_strategy, no_formatting, output_directory, - module_visibility: _module_visibility, + module_visibility, query_path, schema_path, selected_operation, @@ -51,19 +50,7 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Cli); - options.set_module_visibility(match _module_visibility { - Some(v) => match v.to_lowercase().as_str() { - "pub" => Visibility::Public(Pub::default()), - "inherited" => Visibility::Inherited, - _ => Visibility::Restricted(VisRestricted { - pub_token: Pub::default(), - in_token: None, - paren_token: Paren::default(), - path: syn::parse_str(&v).unwrap(), - }), - }, - None => Visibility::Public(Pub::default()), - }); + options.set_module_visibility_from_str(module_visibility.as_deref().unwrap_or("pub")); options.set_fragments_other_variant(fragments_other_variant); diff --git a/graphql_client_codegen/src/codegen_options.rs b/graphql_client_codegen/src/codegen_options.rs index 7b3d8d73..84b758ef 100644 --- a/graphql_client_codegen/src/codegen_options.rs +++ b/graphql_client_codegen/src/codegen_options.rs @@ -2,7 +2,11 @@ use crate::deprecation::DeprecationStrategy; use crate::normalization::Normalization; use proc_macro2::Ident; use std::path::{Path, PathBuf}; -use syn::{self, Visibility}; +use syn::{ + self, + token::{Paren, Pub}, + VisRestricted, Visibility, +}; /// Which context is this code generation effort taking place. #[derive(Debug)] @@ -11,6 +15,8 @@ pub enum CodegenMode { Cli, /// The derive macro defined in graphql_query_derive. Derive, + /// The function-like macro defined in graphql_queries. + FunctionLike, } /// Used to configure code generation. @@ -174,6 +180,22 @@ impl GraphQLClientCodegenOptions { self.module_visibility = Some(visibility); } + /// Parse target module visibility from a string. + pub fn set_module_visibility_from_str(&mut self, visibility: &str) { + let visibility = match visibility.to_lowercase().as_str() { + "pub" => Visibility::Public(Pub::default()), + "inherited" => Visibility::Inherited, + _ => Visibility::Restricted(VisRestricted { + pub_token: Pub::default(), + in_token: None, + paren_token: Paren::default(), + path: syn::parse_str(&visibility).unwrap(), + }), + }; + + self.module_visibility = Some(visibility); + } + /// The name of implementation target struct. pub fn set_struct_name(&mut self, struct_name: String) { self.struct_name = Some(struct_name); diff --git a/graphql_client_codegen/src/generated_module.rs b/graphql_client_codegen/src/generated_module.rs index b225d001..c57aa1e7 100644 --- a/graphql_client_codegen/src/generated_module.rs +++ b/graphql_client_codegen/src/generated_module.rs @@ -79,7 +79,9 @@ impl GeneratedModule<'_> { let impls = self.build_impls()?; let struct_declaration: Option<_> = match self.options.mode { - CodegenMode::Cli => Some(quote!(#module_visibility struct #operation_name_ident;)), + CodegenMode::Cli | CodegenMode::FunctionLike => { + Some(quote!(#module_visibility struct #operation_name_ident;)) + } // The struct is already present in derive mode. CodegenMode::Derive => None, }; diff --git a/graphql_client_codegen/src/lib.rs b/graphql_client_codegen/src/lib.rs index 562691ce..1e220929 100644 --- a/graphql_client_codegen/src/lib.rs +++ b/graphql_client_codegen/src/lib.rs @@ -136,7 +136,7 @@ fn generate_module_token_stream_inner( let operations = match (operations, &options.mode) { (Some(ops), _) => ops, - (None, &CodegenMode::Cli) => query.operations().collect(), + (None, &CodegenMode::Cli | &CodegenMode::FunctionLike) => query.operations().collect(), (None, &CodegenMode::Derive) => { return Err(GeneralError(derive_operation_not_found_error( options.struct_ident(), diff --git a/graphql_query_derive/src/attributes.rs b/graphql_query_derive/src/attributes.rs index 535914fb..ab1b3f4e 100644 --- a/graphql_query_derive/src/attributes.rs +++ b/graphql_query_derive/src/attributes.rs @@ -1,4 +1,4 @@ -use proc_macro2::TokenTree; +use proc_macro2::{TokenStream, TokenTree}; use std::str::FromStr; use syn::Meta; @@ -8,82 +8,75 @@ use graphql_client_codegen::normalization::Normalization; const DEPRECATION_ERROR: &str = "deprecated must be one of 'allow', 'deny', or 'warn'"; const NORMALIZATION_ERROR: &str = "normalization must be one of 'none' or 'rust'"; -pub fn ident_exists(ast: &syn::DeriveInput, ident: &str) -> Result<(), syn::Error> { +/// Extract a token stream of the `graphql` attribute. +pub fn extract_tokens(ast: &syn::DeriveInput) -> Result { let attribute = ast .attrs .iter() .find(|attr| attr.path().is_ident("graphql")) .ok_or_else(|| syn::Error::new_spanned(ast, "The graphql attribute is missing"))?; - if let Meta::List(list) = &attribute.meta { - for item in list.tokens.clone().into_iter() { - if let TokenTree::Ident(ident_) = item { - if ident_ == ident { - return Ok(()); - } + match &attribute.meta { + Meta::List(list) => Ok(list.tokens.clone()), + _ => Err(syn::Error::new_spanned( + ast, + "Unable to parse the graphql attribute", + )), + } +} + +pub fn ident_exists(tokens: &TokenStream, ident: &str) -> Result<(), syn::Error> { + for item in tokens.clone().into_iter() { + if let TokenTree::Ident(ident_) = item { + if ident_ == ident { + return Ok(()); } } } Err(syn::Error::new_spanned( - ast, + tokens, format!("Ident `{}` not found", ident), )) } /// Extract an configuration parameter specified in the `graphql` attribute. -pub fn extract_attr(ast: &syn::DeriveInput, attr: &str) -> Result { - let attribute = ast - .attrs - .iter() - .find(|a| a.path().is_ident("graphql")) - .ok_or_else(|| syn::Error::new_spanned(ast, "The graphql attribute is missing"))?; - - if let Meta::List(list) = &attribute.meta { - let mut iter = list.tokens.clone().into_iter(); - while let Some(item) = iter.next() { - if let TokenTree::Ident(ident) = item { - if ident == attr { - iter.next(); - if let Some(TokenTree::Literal(lit)) = iter.next() { - let lit_str: syn::LitStr = syn::parse_str(&lit.to_string())?; - return Ok(lit_str.value()); - } +pub fn extract_attr(tokens: &TokenStream, attr: &str) -> Result { + let mut iter = tokens.clone().into_iter(); + while let Some(item) = iter.next() { + if let TokenTree::Ident(ident) = item { + if ident == attr { + iter.next(); + if let Some(TokenTree::Literal(lit)) = iter.next() { + let lit_str: syn::LitStr = syn::parse_str(&lit.to_string())?; + return Ok(lit_str.value()); } } } } Err(syn::Error::new_spanned( - ast, + tokens, format!("Attribute `{}` not found", attr), )) } /// Extract a list of configuration parameter values specified in the `graphql` attribute. -pub fn extract_attr_list(ast: &syn::DeriveInput, attr: &str) -> Result, syn::Error> { - let attribute = ast - .attrs - .iter() - .find(|a| a.path().is_ident("graphql")) - .ok_or_else(|| syn::Error::new_spanned(ast, "The graphql attribute is missing"))?; - +pub fn extract_attr_list(tokens: &TokenStream, attr: &str) -> Result, syn::Error> { let mut result = Vec::new(); - if let Meta::List(list) = &attribute.meta { - let mut iter = list.tokens.clone().into_iter(); - while let Some(item) = iter.next() { - if let TokenTree::Ident(ident) = item { - if ident == attr { - if let Some(TokenTree::Group(group)) = iter.next() { - for token in group.stream() { - if let TokenTree::Literal(lit) = token { - let lit_str: syn::LitStr = syn::parse_str(&lit.to_string())?; - result.push(lit_str.value()); - } + let mut iter = tokens.clone().into_iter(); + while let Some(item) = iter.next() { + if let TokenTree::Ident(ident) = item { + if ident == attr { + if let Some(TokenTree::Group(group)) = iter.next() { + for token in group.stream() { + if let TokenTree::Literal(lit) = token { + let lit_str: syn::LitStr = syn::parse_str(&lit.to_string())?; + result.push(lit_str.value()); } - return Ok(result); } + return Ok(result); } } } @@ -91,7 +84,7 @@ pub fn extract_attr_list(ast: &syn::DeriveInput, attr: &str) -> Result Result Result { - extract_attr(ast, "deprecated")? + extract_attr(tokens, "deprecated")? .to_lowercase() .as_str() .parse() - .map_err(|_| syn::Error::new_spanned(ast, DEPRECATION_ERROR.to_owned())) + .map_err(|_| syn::Error::new_spanned(tokens, DEPRECATION_ERROR.to_owned())) } /// Get the deprecation from a struct attribute in the derive case. -pub fn extract_normalization(ast: &syn::DeriveInput) -> Result { - extract_attr(ast, "normalization")? +pub fn extract_normalization(tokens: &TokenStream) -> Result { + extract_attr(tokens, "normalization")? .to_lowercase() .as_str() .parse() - .map_err(|_| syn::Error::new_spanned(ast, NORMALIZATION_ERROR)) + .map_err(|_| syn::Error::new_spanned(tokens, NORMALIZATION_ERROR)) } -pub fn extract_fragments_other_variant(ast: &syn::DeriveInput) -> bool { - extract_attr(ast, "fragments_other_variant") +pub fn extract_fragments_other_variant(tokens: &TokenStream) -> bool { + extract_attr(tokens, "fragments_other_variant") .ok() .and_then(|s| FromStr::from_str(s.as_str()).ok()) .unwrap_or(false) } -pub fn extract_skip_serializing_none(ast: &syn::DeriveInput) -> bool { - ident_exists(ast, "skip_serializing_none").is_ok() +pub fn extract_skip_serializing_none(tokens: &TokenStream) -> bool { + ident_exists(tokens, "skip_serializing_none").is_ok() } #[cfg(test)] @@ -146,6 +139,7 @@ mod test { struct MyQuery; "; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert_eq!( extract_deprecation_strategy(&parsed).unwrap(), DeprecationStrategy::Warn @@ -164,6 +158,7 @@ mod test { struct MyQuery; "; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert_eq!( extract_deprecation_strategy(&parsed).unwrap(), DeprecationStrategy::Deny @@ -182,6 +177,7 @@ mod test { struct MyQuery; "; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); match extract_deprecation_strategy(&parsed) { Ok(_) => panic!("parsed unexpectedly"), Err(e) => assert_eq!(&format!("{}", e), DEPRECATION_ERROR), @@ -200,6 +196,7 @@ mod test { struct MyQuery; "; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert!(extract_fragments_other_variant(&parsed)); } @@ -215,6 +212,7 @@ mod test { struct MyQuery; "; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert!(!extract_fragments_other_variant(&parsed)); } @@ -230,6 +228,7 @@ mod test { struct MyQuery; "; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert!(!extract_fragments_other_variant(&parsed)); } @@ -244,6 +243,7 @@ mod test { struct MyQuery; "; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert!(!extract_fragments_other_variant(&parsed)); } @@ -259,6 +259,7 @@ mod test { struct MyQuery; "#; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert!(extract_skip_serializing_none(&parsed)); } @@ -273,6 +274,7 @@ mod test { struct MyQuery; "#; let parsed = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert!(!extract_skip_serializing_none(&parsed)); } @@ -289,6 +291,7 @@ mod test { struct MyQuery; "#; let parsed: syn::DeriveInput = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert_eq!( extract_attr_list(&parsed, "extern_enums").ok().unwrap(), @@ -309,6 +312,7 @@ mod test { struct MyQuery; "#; let parsed: syn::DeriveInput = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert_eq!( extract_attr_list(&parsed, "variable_types").ok().unwrap(), @@ -329,6 +333,7 @@ mod test { struct MyQuery; "#; let parsed: syn::DeriveInput = syn::parse_str(input).unwrap(); + let parsed = extract_tokens(&parsed).unwrap(); assert_eq!( extract_attr(&parsed, "response_type").ok().unwrap(), diff --git a/graphql_query_derive/src/lib.rs b/graphql_query_derive/src/lib.rs index e1314f16..dcd15edf 100644 --- a/graphql_query_derive/src/lib.rs +++ b/graphql_query_derive/src/lib.rs @@ -13,6 +13,39 @@ use std::{ use proc_macro2::TokenStream; +#[proc_macro] +pub fn graphql_queries(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + match graphql_queries_inner(input) { + Ok(ts) => ts, + Err(err) => err.to_compile_error().into(), + } +} + +fn graphql_queries_inner( + input: proc_macro::TokenStream, +) -> Result { + let tokens = TokenStream::from(input); + let (query_path, schema_path) = build_query_and_schema_path(&tokens)?; + let mut options = build_graphql_client_derive_options( + CodegenMode::FunctionLike, + &tokens, + query_path.clone(), + )?; + + if let Ok(module_visibility) = attributes::extract_attr(&tokens, "module_visibility") { + options.set_module_visibility_from_str(&module_visibility); + } + + generate_module_token_stream(query_path, &schema_path, options) + .map(Into::into) + .map_err(|err| { + syn::Error::new_spanned( + tokens, + format!("Failed to generate GraphQLQuery impl: {}", err), + ) + }) +} + #[proc_macro_derive(GraphQLQuery, attributes(graphql))] pub fn derive_graphql_query(input: proc_macro::TokenStream) -> proc_macro::TokenStream { match graphql_query_derive_inner(input) { @@ -25,9 +58,14 @@ fn graphql_query_derive_inner( input: proc_macro::TokenStream, ) -> Result { let input = TokenStream::from(input); - let ast = syn::parse2(input)?; - let (query_path, schema_path) = build_query_and_schema_path(&ast)?; - let options = build_graphql_client_derive_options(&ast, query_path.clone())?; + let ast: syn::DeriveInput = syn::parse2(input)?; + let tokens = attributes::extract_tokens(&ast)?; + let (query_path, schema_path) = build_query_and_schema_path(&tokens)?; + let mut options = + build_graphql_client_derive_options(CodegenMode::Derive, &tokens, query_path.clone())?; + options.set_struct_ident(ast.ident.clone()); + options.set_module_visibility(ast.vis.clone()); + options.set_operation_name(ast.ident.to_string()); generate_module_token_stream(query_path, &schema_path, options) .map(Into::into) @@ -39,7 +77,7 @@ fn graphql_query_derive_inner( }) } -fn build_query_and_schema_path(input: &syn::DeriveInput) -> Result<(PathBuf, PathBuf), syn::Error> { +fn build_query_and_schema_path(input: &TokenStream) -> Result<(PathBuf, PathBuf), syn::Error> { let cargo_manifest_dir = env::var("CARGO_MANIFEST_DIR").map_err(|_err| { syn::Error::new_spanned( input, @@ -56,7 +94,8 @@ fn build_query_and_schema_path(input: &syn::DeriveInput) -> Result<(PathBuf, Pat } fn build_graphql_client_derive_options( - input: &syn::DeriveInput, + mode: CodegenMode, + input: &TokenStream, query_path: PathBuf, ) -> Result { let variables_derives = attributes::extract_attr(input, "variables_derives").ok(); @@ -68,7 +107,7 @@ fn build_graphql_client_derive_options( let custom_variable_types = attributes::extract_attr_list(input, "variable_types").ok(); let custom_response_type = attributes::extract_attr(input, "response_type").ok(); - let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Derive); + let mut options = GraphQLClientCodegenOptions::new(mode); options.set_query_file(query_path); options.set_fragments_other_variant(fragments_other_variant); options.set_skip_serializing_none(skip_serializing_none); @@ -111,9 +150,6 @@ fn build_graphql_client_derive_options( options.set_custom_response_type(custom_response_type); } - options.set_struct_ident(input.ident.clone()); - options.set_module_visibility(input.vis.clone()); - options.set_operation_name(input.ident.to_string()); options.set_serde_path(syn::parse_quote!(graphql_client::_private::serde)); Ok(options) From d458fb1ad5194ffc115b7cab4306685664ebf165 Mon Sep 17 00:00:00 2001 From: Ava Johnson Date: Mon, 15 Jun 2026 00:35:36 -0700 Subject: [PATCH 2/4] codegen: generate common code separate from operation code --- graphql_client_codegen/src/codegen.rs | 47 ++++++++++------ .../src/generated_module.rs | 4 ++ graphql_client_codegen/src/lib.rs | 55 ++++++++++++++++--- graphql_client_codegen/src/query.rs | 16 +++--- 4 files changed, 91 insertions(+), 31 deletions(-) diff --git a/graphql_client_codegen/src/codegen.rs b/graphql_client_codegen/src/codegen.rs index e33bb1b1..02f7ea21 100644 --- a/graphql_client_codegen/src/codegen.rs +++ b/graphql_client_codegen/src/codegen.rs @@ -15,7 +15,7 @@ use quote::{quote, ToTokens}; use selection::*; use std::collections::BTreeMap; -/// The main code generation function. +/// The main code generation function for the inputs, variables, and response types. pub(crate) fn response_for_query( operation_id: OperationId, options: &GraphQLClientCodegenOptions, @@ -23,14 +23,10 @@ pub(crate) fn response_for_query( ) -> Result { let serde = options.serde_path(); - let all_used_types = all_used_types(operation_id, &query); + let all_used_types = all_used_types(&vec![operation_id], &query); let response_derives = render_derives(options.all_response_derives()); let variable_derives = render_derives(options.all_variable_derives()); - let scalar_definitions = generate_scalar_definitions(&all_used_types, options, query); - let enum_definitions = enums::generate_enum_definitions(&all_used_types, options, query); - let fragment_definitions = - generate_fragment_definitions(&all_used_types, &response_derives, options, &query); let input_object_definitions = inputs::generate_input_object_definitions( &all_used_types, options, @@ -48,26 +44,45 @@ pub(crate) fn response_for_query( use #serde::{Serialize, Deserialize}; use super::*; + #(#input_object_definitions)* + + #variables_struct + + #definitions + }; + + Ok(q) +} + +/// The main code generation function for scalars, enums, and fragments. +pub(crate) fn common_for_queries( + all_operations: &[OperationId], + options: &GraphQLClientCodegenOptions, + query: BoundQuery<'_>, +) -> Result { + let all_used_types = all_used_types(all_operations, &query); + let response_derives = render_derives(options.all_response_derives()); + + let scalar_definitions = generate_scalar_definitions(&all_used_types, options, query); + let enum_definitions = enums::generate_enum_definitions(&all_used_types, options, query); + let fragment_definitions = + generate_fragment_definitions(&all_used_types, &response_derives, options, &query); + + let q = quote! { #[allow(dead_code)] - type Boolean = bool; + pub type Boolean = bool; #[allow(dead_code)] - type Float = f64; + pub type Float = f64; #[allow(dead_code)] - type Int = i64; + pub type Int = i64; #[allow(dead_code)] - type ID = String; + pub type ID = String; #(#scalar_definitions)* #(#enum_definitions)* - #(#input_object_definitions)* - - #variables_struct - #(#fragment_definitions)* - - #definitions }; Ok(q) diff --git a/graphql_client_codegen/src/generated_module.rs b/graphql_client_codegen/src/generated_module.rs index c57aa1e7..63d55fe3 100644 --- a/graphql_client_codegen/src/generated_module.rs +++ b/graphql_client_codegen/src/generated_module.rs @@ -30,6 +30,7 @@ pub(crate) struct GeneratedModule<'a> { pub resolved_query: &'a crate::query::Query, pub schema: &'a crate::schema::Schema, pub options: &'a crate::GraphQLClientCodegenOptions, + pub common: &'a TokenStream, } impl GeneratedModule<'_> { @@ -62,6 +63,7 @@ impl GeneratedModule<'_> { let operation_name = self.operation; let operation_name_ident = self.options.normalization().operation(self.operation); let operation_name_ident = Ident::new(&operation_name_ident, Span::call_site()); + let common = &self.common; // Force cargo to refresh the generated code when the query file changes. let query_include = self @@ -99,6 +101,8 @@ impl GeneratedModule<'_> { #query_include + #common + #impls } diff --git a/graphql_client_codegen/src/lib.rs b/graphql_client_codegen/src/lib.rs index 1e220929..ab4418bc 100644 --- a/graphql_client_codegen/src/lib.rs +++ b/graphql_client_codegen/src/lib.rs @@ -26,6 +26,7 @@ mod type_qualifiers; mod tests; pub use crate::codegen_options::{CodegenMode, GraphQLClientCodegenOptions}; +use crate::query::BoundQuery; use std::{collections::BTreeMap, fmt::Display, io}; @@ -146,22 +147,60 @@ fn generate_module_token_stream_inner( } }; - // The generated modules. - let mut modules = Vec::with_capacity(operations.len()); + let common = crate::codegen::common_for_queries( + &operations.iter().map(|o| o.0).collect::>(), + &options, + BoundQuery { + query: &query, + schema, + }, + )?; - for operation in &operations { - let generated = generated_module::GeneratedModule { + // The generated modules. + let modules = if operations.len() == 1 { + let module = generated_module::GeneratedModule { query_string: query_string.as_str(), schema, resolved_query: &query, - operation: &operation.1.name, + operation: &operations[0].1.name, options: &options, + common: &common, } .to_token_stream()?; - modules.push(generated); - } - let modules = quote! { #(#modules)* }; + quote! { #module } + } else { + let use_common = quote! { use super::common::*; }; + + let mut modules = Vec::with_capacity(operations.len()); + + for operation in &operations { + let generated = generated_module::GeneratedModule { + query_string: query_string.as_str(), + schema, + resolved_query: &query, + operation: &operation.1.name, + options: &options, + common: &use_common, + } + .to_token_stream()?; + modules.push(generated); + } + + let module_visibility = &options.module_visibility(); + let serde = options.serde_path(); + + quote! { + #module_visibility mod common { + use #serde::{Serialize, Deserialize}; + use super::*; + + #common + } + + #(#modules)* + } + }; Ok(modules) } diff --git a/graphql_client_codegen/src/query.rs b/graphql_client_codegen/src/query.rs index 71d0798f..3f2c23af 100644 --- a/graphql_client_codegen/src/query.rs +++ b/graphql_client_codegen/src/query.rs @@ -707,17 +707,19 @@ pub(crate) fn walk_operation_variables( .filter(move |(_id, var)| var.operation_id == operation_id) } -pub(crate) fn all_used_types(operation_id: OperationId, query: &BoundQuery<'_>) -> UsedTypes { +pub(crate) fn all_used_types(all_operations: &[OperationId], query: &BoundQuery<'_>) -> UsedTypes { let mut used_types = UsedTypes::default(); - let operation = query.query.get_operation(operation_id); + for operation_id in all_operations { + let operation = query.query.get_operation(*operation_id); - for (_id, selection) in query.query.walk_selection_set(&operation.selection_set) { - selection.collect_used_types(&mut used_types, query); - } + for (_id, selection) in query.query.walk_selection_set(&operation.selection_set) { + selection.collect_used_types(&mut used_types, query); + } - for (_id, variable) in walk_operation_variables(operation_id, query.query) { - variable.collect_used_types(&mut used_types, query.schema); + for (_id, variable) in walk_operation_variables(*operation_id, query.query) { + variable.collect_used_types(&mut used_types, query.schema); + } } used_types From 94e7adca8bc5f6f2b96a221d4e5079d6d6b0ff3a Mon Sep 17 00:00:00 2001 From: Ava Johnson Date: Mon, 15 Jun 2026 16:35:17 -0700 Subject: [PATCH 3/4] Add tests --- graphql_client/tests/module_visibility.rs | 14 ++++++++++++++ .../tests/module_visibility/query.graphql | 3 +++ .../tests/module_visibility/schema.graphql | 7 +++++++ graphql_client/tests/shared_fragments.rs | 12 ++++++++++++ .../tests/shared_fragments/query.graphql | 11 +++++++++++ .../tests/shared_fragments/schema.graphql | 8 ++++++++ 6 files changed, 55 insertions(+) create mode 100644 graphql_client/tests/module_visibility.rs create mode 100644 graphql_client/tests/module_visibility/query.graphql create mode 100644 graphql_client/tests/module_visibility/schema.graphql create mode 100644 graphql_client/tests/shared_fragments.rs create mode 100644 graphql_client/tests/shared_fragments/query.graphql create mode 100644 graphql_client/tests/shared_fragments/schema.graphql diff --git a/graphql_client/tests/module_visibility.rs b/graphql_client/tests/module_visibility.rs new file mode 100644 index 00000000..0254545e --- /dev/null +++ b/graphql_client/tests/module_visibility.rs @@ -0,0 +1,14 @@ +mod inner { + use graphql_client::*; + + graphql_queries!( + query_path = "tests/module_visibility/query.graphql", + schema_path = "tests/module_visibility/schema.graphql", + module_visibility = "pub" + ); +} + +#[test] +fn module_visibility() { + let _ = inner::test_query::ResponseData { value: None }; +} diff --git a/graphql_client/tests/module_visibility/query.graphql b/graphql_client/tests/module_visibility/query.graphql new file mode 100644 index 00000000..e087e1e1 --- /dev/null +++ b/graphql_client/tests/module_visibility/query.graphql @@ -0,0 +1,3 @@ +query TestQuery { + value +} diff --git a/graphql_client/tests/module_visibility/schema.graphql b/graphql_client/tests/module_visibility/schema.graphql new file mode 100644 index 00000000..40e85dae --- /dev/null +++ b/graphql_client/tests/module_visibility/schema.graphql @@ -0,0 +1,7 @@ +schema { + query: QueryRoot +} + +type QueryRoot { + value: String +} diff --git a/graphql_client/tests/shared_fragments.rs b/graphql_client/tests/shared_fragments.rs new file mode 100644 index 00000000..c23ab92a --- /dev/null +++ b/graphql_client/tests/shared_fragments.rs @@ -0,0 +1,12 @@ +use graphql_client::*; + +graphql_queries!( + query_path = "tests/shared_fragments/query.graphql", + schema_path = "tests/shared_fragments/schema.graphql", +); + +#[test] +fn shared_fragments() { + let _: common::FragmentReference = a::ResponseData { in_fragment: None }; + let _: a::ResponseData = b::ResponseData { in_fragment: None }; +} diff --git a/graphql_client/tests/shared_fragments/query.graphql b/graphql_client/tests/shared_fragments/query.graphql new file mode 100644 index 00000000..b5036ac8 --- /dev/null +++ b/graphql_client/tests/shared_fragments/query.graphql @@ -0,0 +1,11 @@ +fragment FragmentReference on QueryRoot { + inFragment +} + +query A { + ...FragmentReference +} + +query B { + ...FragmentReference +} diff --git a/graphql_client/tests/shared_fragments/schema.graphql b/graphql_client/tests/shared_fragments/schema.graphql new file mode 100644 index 00000000..6070bb0f --- /dev/null +++ b/graphql_client/tests/shared_fragments/schema.graphql @@ -0,0 +1,8 @@ +schema { + query: QueryRoot +} + +type QueryRoot { + extra: String + inFragment: String +} From 4cd2a2ebfa2d573421189053b5caeee172af2def Mon Sep 17 00:00:00 2001 From: Ava Johnson Date: Sat, 20 Jun 2026 16:03:39 -0700 Subject: [PATCH 4/4] Fix `custom_scalars_module` --- .../tests/custom_scalars/query.graphql | 4 +++ graphql_client/tests/custom_scalars_module.rs | 36 +++++++++++++++++++ graphql_client_codegen/src/codegen.rs | 4 +-- 3 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 graphql_client/tests/custom_scalars_module.rs diff --git a/graphql_client/tests/custom_scalars/query.graphql b/graphql_client/tests/custom_scalars/query.graphql index d8b7b69c..79f13f17 100644 --- a/graphql_client/tests/custom_scalars/query.graphql +++ b/graphql_client/tests/custom_scalars/query.graphql @@ -1,3 +1,7 @@ query CustomScalarQuery { address } + +query AnotherCustomScalarQuery { + address +} diff --git a/graphql_client/tests/custom_scalars_module.rs b/graphql_client/tests/custom_scalars_module.rs new file mode 100644 index 00000000..9cb72ce8 --- /dev/null +++ b/graphql_client/tests/custom_scalars_module.rs @@ -0,0 +1,36 @@ +use graphql_client::*; +use serde_json::json; + +use std::net::Ipv4Addr; + +mod scalars { + // Important! The NetworkAddress scalar should deserialize to an Ipv4Addr from the Rust std library. + pub type NetworkAddress = super::Ipv4Addr; +} + +graphql_queries!( + query_path = "tests/custom_scalars/query.graphql", + schema_path = "tests/custom_scalars/schema.graphql" + custom_scalars_module = "scalars" +); + +#[test] +fn custom_scalars() { + let valid_response = json!({ + "address": "127.0.1.2", + }); + + let valid_addr = + serde_json::from_value::(valid_response).unwrap(); + + assert_eq!( + valid_addr.address.unwrap(), + "127.0.1.2".parse::().unwrap() + ); + + let invalid_response = json!({ + "address": "localhost", + }); + + assert!(serde_json::from_value::(invalid_response).is_err()); +} diff --git a/graphql_client_codegen/src/codegen.rs b/graphql_client_codegen/src/codegen.rs index 02f7ea21..055c5e85 100644 --- a/graphql_client_codegen/src/codegen.rs +++ b/graphql_client_codegen/src/codegen.rs @@ -186,9 +186,9 @@ fn generate_scalar_definitions<'a, 'schema: 'a>( ); if let Some(custom_scalars_module) = options.custom_scalars_module() { - quote!(type #ident = #custom_scalars_module::#ident;) + quote!(pub type #ident = #custom_scalars_module::#ident;) } else { - quote!(type #ident = super::#ident;) + quote!(pub type #ident = super::#ident;) } }) }