Skip to content

Commit fda50f4

Browse files
authored
Add support for grouped method parameters (#897)
* Add support for grouped method parameters The override decorator allows grouping parameters into a singular model to be passed to the method. Note that only required parameters will be grouped; any optional parameters will continue to be placed in the method's default method options type. Added ParameterGroup to the code model. Parameters now have an optional group field indicating if they belong to a parameter group. The MethodOptions type was replaced by ParameterGroup as there's overlap between the two constructs. This also enabled the removal of some hard coded bits when constructing the method options param instance. * fix doc comment
1 parent 5ca8235 commit fda50f4

25 files changed

Lines changed: 962 additions & 102 deletions

File tree

packages/typespec-rust/.scripts/tspcompile.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ const azureHttpSpecsGroup = {
8989
'spector_emptystringasnone': {input: 'azure/client-generator-core/deserialize-empty-string-as-null'},
9090
'spector_flattenproperty': {input: 'azure/client-generator-core/flatten-property'},
9191
'spector_corenextlinkverb': {input: 'azure/client-generator-core/next-link-verb'},
92-
//'spector_coreoverride': {input: 'azure/client-generator-core/override/client.tsp'},
92+
'spector_coreoverride': {input: 'azure/client-generator-core/override/client.tsp'},
9393
'spector_coreusage': {input: 'azure/client-generator-core/usage'},
9494
'spector_basic': {input: 'azure/core/basic'},
9595
'spector_lrorpc': {input: 'azure/core/lro/rpc'},

packages/typespec-rust/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
### Features Added
1313

1414
* Define a `pub(crate)` constant for `api-version` to use in hand-authored `Default` implementation on client options.
15-
* Add support for TypeSpec `union` types.
15+
* Added support for the following.
16+
* TypeSpec `union` types.
17+
* Grouped method parameters via the `@override` decorator.
1618

1719
### Other Changes
1820

packages/typespec-rust/src/codegen/clients.ts

Lines changed: 106 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,16 @@ export function emitClients(module: rust.ModuleContainer): ClientModules | undef
316316
for (const method of client.methods) {
317317
if (method.kind !== 'clientaccessor') {
318318
// client method options types are always in the same module as their client method
319-
use.add(`${utils.buildImportPath(client.module, client.module)}::models`, method.options.type.name);
319+
use.add(`${utils.buildImportPath(client.module, client.module)}::models`, method.options.type.type.name);
320+
321+
// add imports for parameter group types
322+
const seenGroups = new Set<string>();
323+
for (const param of method.params) {
324+
if (param.group && !seenGroups.has(param.group.type.name)) {
325+
seenGroups.add(param.group.type.name);
326+
use.add(`${utils.buildImportPath(client.module, client.module)}::models`, param.group.type.name);
327+
}
328+
}
320329
}
321330
}
322331

@@ -354,34 +363,35 @@ function getMethodOptions(module: rust.ModuleContainer): helpers.Module | undefi
354363

355364
// method options struct
356365
let block = '';
357-
block += helpers.formatDocComment(method.options.type.docs);
366+
const optionsStruct = method.options.type.type;
367+
block += helpers.formatDocComment(optionsStruct.docs);
358368
use.add('azure_core::fmt', 'SafeDebug');
359369
block += '#[derive(Clone, Default, SafeDebug)]\n';
360-
block += `${helpers.emitVisibility(method.options.type.visibility)}struct ${helpers.getTypeDeclaration(method.options.type)} {\n`;
361-
visTracker.update(method.options.type.visibility);
362-
for (let i = 0; i < method.options.type.fields.length; ++i) {
363-
const field = method.options.type.fields[i];
370+
block += `${helpers.emitVisibility(optionsStruct.visibility)}struct ${helpers.getTypeDeclaration(optionsStruct)} {\n`;
371+
visTracker.update(optionsStruct.visibility);
372+
for (let i = 0; i < optionsStruct.fields.length; ++i) {
373+
const field = optionsStruct.fields[i];
364374
use.addForType(field.type);
365375
const fieldDocs = helpers.formatDocComment(field.docs);
366376
if (fieldDocs.length > 0) {
367377
block += `${indent.get()}${fieldDocs}`;
368378
}
369379
block += `${indent.get()}${helpers.emitVisibility(method.visibility)}${field.name}: ${helpers.getTypeDeclaration(field.type)},\n`;
370-
if (i + 1 < method.options.type.fields.length) {
380+
if (i + 1 < optionsStruct.fields.length) {
371381
block += '\n';
372382
}
373383
}
374384
block += '}\n';
375385

376386
if (method.kind === 'pageable' || method.kind === 'lro') {
377387
block += '\n';
378-
block += `impl ${helpers.getTypeDeclaration(method.options.type, 'anonymous')} {\n`;
379-
const wrappedTypeName = helpers.wrapInBackTicks(helpers.getTypeDeclaration(method.options.type, 'omit'));
388+
block += `impl ${helpers.getTypeDeclaration(optionsStruct, 'anonymous')} {\n`;
389+
const wrappedTypeName = helpers.wrapInBackTicks(helpers.getTypeDeclaration(optionsStruct, 'omit'));
380390
block += `${indent.get()}/// Transforms this [${wrappedTypeName}] into a new ${wrappedTypeName} that owns the underlying data, cloning it if necessary.\n`;
381-
block += `${indent.get()}pub fn into_owned(self) -> ${method.options.type.name}<'static> {\n`;
382-
block += `${indent.push().get()}${method.options.type.name} {\n`;
391+
block += `${indent.get()}pub fn into_owned(self) -> ${optionsStruct.name}<'static> {\n`;
392+
block += `${indent.push().get()}${optionsStruct.name} {\n`;
383393
indent.push();
384-
for (const field of method.options.type.fields) {
394+
for (const field of optionsStruct.fields) {
385395
if (field.type.kind === 'clientMethodOptions' || field.type.kind === 'pagerOptions' || field.type.kind === 'pollerOptions') {
386396
block += `${indent.get()}${field.name}: ${field.type.name} {\n`;
387397
block += `${indent.push().get()}context: self.${field.name}.context.into_owned(),\n`;
@@ -396,7 +406,39 @@ function getMethodOptions(module: rust.ModuleContainer): helpers.Module | undefi
396406
block += '}\n';
397407
}
398408

399-
structBlocks.push({ name: method.options.type.name, body: block });
409+
structBlocks.push({ name: optionsStruct.name, body: block });
410+
411+
// parameter group structs
412+
const seenGroups = new Set<rust.ParameterGroup<rust.Struct>>();
413+
for (const param of method.params) {
414+
if (param.group) {
415+
seenGroups.add(param.group);
416+
}
417+
}
418+
419+
for (const group of seenGroups) {
420+
const groupParams = method.params.filter(p => p.group === group);
421+
let groupBlock = '';
422+
groupBlock += helpers.formatDocComment(group.type.docs);
423+
use.add('azure_core::fmt', 'SafeDebug');
424+
groupBlock += '#[derive(Clone, SafeDebug)]\n';
425+
groupBlock += `${helpers.emitVisibility(group.type.visibility)}struct ${group.type.name}${group.type.lifetime ? `<${group.type.lifetime.name}>` : ''} {\n`;
426+
visTracker.update(group.type.visibility);
427+
for (let i = 0; i < groupParams.length; i++) {
428+
const field = groupParams[i];
429+
use.addForType(field.type);
430+
const fieldDocs = helpers.formatDocComment(field.docs);
431+
if (fieldDocs.length > 0) {
432+
if (i > 0) {
433+
groupBlock += '\n';
434+
}
435+
groupBlock += `${indent.get()}${fieldDocs}`;
436+
}
437+
groupBlock += `${indent.get()}${helpers.emitVisibility(group.type.visibility)}${field.name}: ${helpers.getTypeDeclaration(field.type)},\n`;
438+
}
439+
groupBlock += '}\n';
440+
structBlocks.push({ name: group.type.name, body: groupBlock });
441+
}
400442
}
401443
}
402444

@@ -433,7 +475,18 @@ function getParamsBlockDocComment(indent: helpers.indentation, callable: rust.Co
433475
};
434476

435477
let paramsContent = '';
478+
const documentedGroups = new Set<string>();
436479
for (const param of callable.params) {
480+
if ('group' in param && param.group) {
481+
// required parameter group appears as a single param in the method sig
482+
const group = param.group;
483+
if (!documentedGroups.has(group.name)) {
484+
documentedGroups.add(group.name);
485+
paramsContent += helpers.formatDocComment(group.docs, false, formatParamBullet(group.name), indent);
486+
}
487+
continue;
488+
}
489+
437490
let optional = false;
438491
if ('optional' in param) {
439492
optional = param.optional;
@@ -517,7 +570,18 @@ function getMethodParamsCountAndSig(method: rust.MethodType, use: Use): { count:
517570
++count;
518571
}
519572
} else {
573+
const emittedGroups = new Set<string>();
520574
for (const param of method.params) {
575+
if (param.group) {
576+
if (!emittedGroups.has(param.group.name)) {
577+
emittedGroups.add(param.group.name);
578+
// required parameter group appears as a single struct parameter in the method signature
579+
paramsSig.push(`${param.group.name}: ${helpers.getTypeDeclaration(param.group.type, 'anonymous')}`);
580+
++count;
581+
}
582+
continue;
583+
}
584+
521585
const paramType = helpers.unwrapType(param.type);
522586
if (paramType.kind === 'literal') {
523587
// literal params are embedded directly in the code (e.g. accept header param)
@@ -536,7 +600,7 @@ function getMethodParamsCountAndSig(method: rust.MethodType, use: Use): { count:
536600
}
537601
}
538602

539-
paramsSig.push(`options: ${helpers.getTypeDeclaration(method.options, 'anonymous')}`);
603+
paramsSig.push(`${method.options.name}: ${helpers.getTypeDeclaration(method.options.type, 'anonymous')}`);
540604
++count;
541605
}
542606

@@ -717,7 +781,6 @@ function getMethodParamGroup(method: ClientMethod): MethodParamGroups {
717781
const pathParams = new Array<PathParamType>();
718782
const queryParams = new Array<QueryParamType>();
719783
const partialBodyParams = new Array<rust.PartialBodyParameter>();
720-
721784
for (const param of method.params) {
722785
switch (param.kind) {
723786
case 'headerScalar':
@@ -839,8 +902,9 @@ function constructUrl(indent: helpers.indentation, use: Use, method: ClientMetho
839902
let wrapSortedVec: (s: string) => string = (s) => s;
840903
let paramExpression: string;
841904
if (pathParam.kind === 'pathHashMap') {
905+
const pathParamRef = qualifiedParamName(pathParam);
842906
wrapSortedVec = (s) => `${indent.get()}{`
843-
+ `${indent.push().get()}let mut ${pathParam.name}_vec = ${pathParam.name}.iter().collect::<Vec<_>>();\n`
907+
+ `${indent.push().get()}let mut ${pathParam.name}_vec = ${pathParamRef}.iter().collect::<Vec<_>>();\n`
844908
+ `${indent.get()}${pathParam.name}_vec.sort_by_key(|p| p.0);\n`
845909
+ `${s}`
846910
+ `${indent.pop().get()}}`;
@@ -875,16 +939,17 @@ function constructUrl(indent: helpers.indentation, use: Use, method: ClientMetho
875939
break;
876940
}
877941
} else if (pathParam.kind === 'pathCollection') {
878-
paramExpression = `&${pathParam.name}.join(",")`;
942+
const pathParamRef = qualifiedParamName(pathParam);
943+
paramExpression = `&${pathParamRef}.join(",")`;
879944
switch (pathParam.style) {
880945
case 'path':
881-
paramExpression = `&format!("/{}", ${pathParam.name}.join("${pathParam.explode ? '/' : ','}"))`;
946+
paramExpression = `&format!("/{}", ${pathParamRef}.join("${pathParam.explode ? '/' : ','}"))`;
882947
break;
883948
case 'label':
884-
paramExpression = `&format!(".{}", ${pathParam.name}.join("${pathParam.explode ? '.' : ','}"))`;
949+
paramExpression = `&format!(".{}", ${pathParamRef}.join("${pathParam.explode ? '.' : ','}"))`;
885950
break;
886951
case 'matrix':
887-
paramExpression = `&format!(";${pathParam.name}={}", ${pathParam.name}.join(`
952+
paramExpression = `&format!(";${pathParam.name}={}", ${pathParamRef}.join(`
888953
+ `"${pathParam.explode ? `;${pathParam.name}=` : ','}"))`;
889954
break;
890955
}
@@ -961,8 +1026,9 @@ function constructUrl(indent: helpers.indentation, use: Use, method: ClientMetho
9611026
for (const queryParam of paramGroups.query) {
9621027
if (queryParam.kind === 'queryCollection' && queryParam.format === 'multi') {
9631028
body += getParamValueHelper(indent, queryParam, () => {
1029+
const queryParamRef = qualifiedParamName(queryParam);
9641030
const valueVar = queryParam.name[0];
965-
let text = `${indent.get()}for ${valueVar} in ${queryParam.name}.iter() {\n`;
1031+
let text = `${indent.get()}for ${valueVar} in ${queryParamRef}.iter() {\n`;
9661032
// if queryParam is a &[&str] then we'll need to deref the iterator
9671033
const deref = utils.asTypeOf(queryParam.type, 'str', 'ref', 'slice', 'ref') ? '*' : '';
9681034
text += `${indent.push().get()}query_builder.append_pair("${queryParam.key}", ${deref}${getHeaderPathQueryParamValue(use, queryParam, !queryParam.optional, false, valueVar)});\n`;
@@ -971,8 +1037,9 @@ function constructUrl(indent: helpers.indentation, use: Use, method: ClientMetho
9711037
});
9721038
} else if (queryParam.kind === 'queryHashMap') {
9731039
body += getParamValueHelper(indent, queryParam, () => {
1040+
const queryParamRef = qualifiedParamName(queryParam);
9741041
let text = `${indent.get()}{\n`;
975-
text += `${indent.push().get()}let mut ${queryParam.name}_vec = ${queryParam.name}.iter().collect::<Vec<_>>();\n`;
1042+
text += `${indent.push().get()}let mut ${queryParam.name}_vec = ${queryParamRef}.iter().collect::<Vec<_>>();\n`;
9761043
text += `${indent.get()}${queryParam.name}_vec.sort_by_key(|p| p.0);\n`;
9771044
if (queryParam.explode) {
9781045
text += `${indent.get()}for (k, v) in ${queryParam.name}_vec.iter() {\n`;
@@ -1044,7 +1111,8 @@ function applyHeaderParams(indent: helpers.indentation, use: Use, method: Client
10441111

10451112
body += getParamValueHelper(indent, headerParam, () => {
10461113
if (headerParam.kind === 'headerHashMap') {
1047-
let setter = `for (k, v) in ${headerParam.name} {\n`;
1114+
const headerParamRef = qualifiedParamName(headerParam);
1115+
let setter = `for (k, v) in ${headerParamRef} {\n`;
10481116
setter += `${indent.push().get()}${requestVarName}.insert_header(format!("${headerParam.header}-{k}"), v);\n`;
10491117
setter += `${indent.pop().get()}}\n`;
10501118
return setter;
@@ -1215,8 +1283,9 @@ function emitEmptyPathParamCheck(indent: helpers.indentation, param: PathParamTy
12151283
// no length to check so bail
12161284
return '';
12171285
}
1286+
const paramRef = qualifiedParamName(param);
12181287
return helpers.buildIfBlock(indent, {
1219-
condition: `${param.name}${toString}.is_empty()`,
1288+
condition: `${paramRef}${toString}.is_empty()`,
12201289
body: (indent) => `${indent.get()}return Err(azure_core::Error::with_message(azure_core::error::ErrorKind::Other, "parameter ${param.name} cannot be empty"));\n`,
12211290
});
12221291
}
@@ -1852,6 +1921,8 @@ function getHeaderPathQueryParamValue(use: Use, param: HeaderParamType | PathPar
18521921
paramName = 'self.' + paramName;
18531922
} else if (overrideParamName) {
18541923
paramName = overrideParamName;
1924+
} else if (param.group) {
1925+
paramName = qualifiedParamName(param);
18551926
}
18561927

18571928
const encodeBytes = function (type: rust.EncodedBytes, param?: string): string {
@@ -1886,6 +1957,8 @@ function getHeaderPathQueryParamValue(use: Use, param: HeaderParamType | PathPar
18861957
// param requires borrowing.
18871958
let mustBorrow = !helpers.isQueryParameter(param);
18881959

1960+
const isGrouped = !!param.group;
1961+
18891962
const paramType = helpers.unwrapType(param.type);
18901963
// we want multi to hit the else case so the necessary conversions etc can happen
18911964
if ((param.kind === 'headerCollection' || param.kind === 'queryCollection') && param.format !== 'multi') {
@@ -1911,8 +1984,8 @@ function getHeaderPathQueryParamValue(use: Use, param: HeaderParamType | PathPar
19111984
switch (paramType.kind) {
19121985
case 'String':
19131986
paramValue = paramName;
1914-
// if the param is on the client, then we must borrow
1915-
mustBorrow = param.location === 'client' && fromSelf;
1987+
// if the param is on the client or in a group struct, then we must borrow
1988+
mustBorrow = isGrouped || (param.location === 'client' && fromSelf);
19161989
break;
19171990
case 'str':
19181991
paramValue = paramName;
@@ -1956,8 +2029,8 @@ function getHeaderPathQueryParamValue(use: Use, param: HeaderParamType | PathPar
19562029
case 'headerHashMap':
19572030
case 'headerScalar':
19582031
// for non-copyable params (e.g. String), we need to borrow them if they're on the
1959-
// client or we're in a closure and the param is required (header params are always owned)
1960-
mustBorrow = nonCopyableType(param.type) && (param.location === 'client' || (!fromSelf && !param.optional));
2032+
// client, in a group struct, or we're in a closure and the param is required (header params are always owned)
2033+
mustBorrow = nonCopyableType(param.type) && (isGrouped || param.location === 'client' || (!fromSelf && !param.optional));
19612034
break;
19622035
}
19632036

@@ -2037,3 +2110,8 @@ function isEnumString(type: rust.Type): type is rust.Enum {
20372110
const unwrapped = helpers.unwrapType(type);
20382111
return unwrapped.kind === 'enum' && unwrapped.type === 'String';
20392112
}
2113+
2114+
/** returns the qualified name for a param, prefixing with the group name when the param belongs to a parameter group */
2115+
function qualifiedParamName(param: rust.MethodParameter): string {
2116+
return param.group ? `${param.group.name}.${param.name}` : param.name;
2117+
}

packages/typespec-rust/src/codegen/codeGenerator.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ function sortContent(content: rust.ModuleContainer): void {
192192
} else if (method.kind === 'pageable' && method.strategy?.kind === 'nextLink') {
193193
method.strategy.reinjectedParams.sort((a: rust.MethodParameter, b: rust.MethodParameter) => sortAscending(a.name, b.name));
194194
}
195-
method.options.type.fields.sort((a: rust.StructField, b: rust.StructField) => { return sortAscending(a.name, b.name); });
195+
method.options.type.type.fields.sort((a: rust.StructField, b: rust.StructField) => { return sortAscending(a.name, b.name); });
196196
method.responseHeaders?.headers.sort((a: rust.ResponseHeader, b: rust.ResponseHeader) => sortAscending(a.header, b.header));
197197
}
198198
}

packages/typespec-rust/src/codegen/helpers.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ export function getTypeDeclaration(type: rust.Client | rust.ResponseHeadersTrait
235235
// we explicitly omit the Response<T> from the type decl
236236
return `Poller<${getTypeDeclaration(type.type.content, withLifetime)}>`;
237237
case 'ref':
238-
return `&${getTypeDeclaration(type.type)}`;
238+
return `&${type.lifetime ? `${type.lifetime.name} ` : ''}${getTypeDeclaration(type.type)}`;
239239
case 'requestContent': {
240240
const formatType = `${type.format !== 'JsonFormat' ? `, ${type.format}` : ''}`;
241241
return `${type.name}<${getTypeDeclaration(type.content, withLifetime)}${formatType}>`;

0 commit comments

Comments
 (0)