Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,21 @@ protected override IReadOnlyList<MethodBodyStatement> BuildAttributes()
var buildableProviders = new HashSet<TypeProvider>(s_typeProviderNameComparer);
var buildableTypes = new HashSet<CSharpType>(s_cSharpTypeNameComparer);

// Process all providers from the output library to discover types from methods and properties
var providers = ScmCodeModelGenerator.Instance.OutputLibrary.TypeProviders;
// Base-model traversal can encounter equivalent provider instances that are not reference-equal to
// the output-library roots, so keep the output-library provider set name-comparable.
var contextEligibleOutputProviders = new HashSet<TypeProvider>(
ScmCodeModelGenerator.Instance.OutputLibrary.TypeProviders,
s_typeProviderNameComparer);

// Process each provider recursively
foreach (var provider in providers)
// Process each output-library provider recursively to discover types from methods and properties.
foreach (var provider in contextEligibleOutputProviders)
{
// Only output-library providers get standalone context entries.
if (ImplementsModelReaderWriter(provider))
{
buildableProviders.Add(provider);
}

CollectBuildableTypeProvidersRecursive(provider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
}

Expand Down Expand Up @@ -131,12 +140,6 @@ private void CollectBuildableTypeProvidersRecursive(
return;
}

// Only add to buildableProviders if it implements MRW
if (ImplementsModelReaderWriter(currentProvider))
{
buildableProviders.Add(currentProvider);
}

// Process all providers to discover types from methods and properties
if (currentProvider is not null)
{
Expand Down Expand Up @@ -182,8 +185,9 @@ private void CollectBuildableTypesRecursiveCore(

if (provider is ModelProvider modelProvider && modelProvider.BaseModelProvider != null)
{
// For base model types, we need to process their properties as well, but we don't need to add the base model type itself
CollectBuildableTypesRecursiveCore(modelProvider.BaseModelProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
// Traverse base model properties for discoverable types, but do not add the base model
// itself as a context entry unless it was in the output-library seed set.
CollectBuildableTypeProvidersRecursive(modelProvider.BaseModelProvider, visitedTypes, visitedTypeProviders, buildableProviders, buildableTypes);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1114,16 +1114,14 @@ private List<MethodBodyStatement> BuildDeserializePropertiesStatements(ScopedApi
var rawBinaryData = _rawDataField;
if (rawBinaryData == null)
{
var baseModelProvider = _model.BaseModelProvider;
while (baseModelProvider != null)
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
var field = baseModelProvider.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
if (field != null)
{
rawBinaryData = field;
break;
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}
}

Expand Down Expand Up @@ -1737,8 +1735,7 @@ private MethodBodyStatement[] CreateWritePropertiesStatements(bool isDynamicMode

if (isDynamicModelWithNonDynamicBase)
{
var baseModelProvider = _model.BaseModelProvider;
while (baseModelProvider != null)
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
foreach (var property in baseModelProvider.CanonicalView.Properties)
{
Expand All @@ -1759,8 +1756,6 @@ private MethodBodyStatement[] CreateWritePropertiesStatements(bool isDynamicMode

propertyStatements.Add(CreateWritePropertyStatement(field.WireInfo, field.Type, field.Name, field, field.WireInfo?.SerializationFormat));
}

baseModelProvider = baseModelProvider.BaseModelProvider;
}
}

Expand Down Expand Up @@ -2575,21 +2570,20 @@ private MethodBodyStatement CreateWriteAdditionalPropertiesStatement()
PropertyProvider? property = _model.Properties.FirstOrDefault(
p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
// search in the base model if the property is not found in the current model
return property ?? _model.BaseModelProvider?.Properties.FirstOrDefault(
p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
return property ?? EnumerateBaseModelProviders()
.SelectMany(m => m.Properties)
.FirstOrDefault(p => p.BackingField?.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName);
}

private MethodProvider? FindCustomHookMethod(string hookName)
{
var model = _model;
while (model != null)
foreach (var model in EnumerateModelAndBaseModelProviders())
{
var method = model.CanonicalView.Methods.FirstOrDefault(m => m.Signature.Name == hookName);
if (method != null)
{
return method;
}
model = model.BaseModelProvider;
}
return null;
}
Expand Down Expand Up @@ -2636,9 +2630,8 @@ private List<AttributeStatement> GetSerializationAttributes()
List<AttributeStatement> serializationAttributes = _model.CustomCodeView?.Attributes
.Where(a => a.Type.Name == CodeGenAttributes.CodeGenSerializationAttributeName)
.ToList() ?? [];
var baseModelProvider = _model.BaseModelProvider;

while (baseModelProvider != null)
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
var customCodeView = baseModelProvider.CustomCodeView;
if (customCodeView != null)
Expand All @@ -2647,12 +2640,36 @@ private List<AttributeStatement> GetSerializationAttributes()
.AddRange(customCodeView.Attributes
.Where(a => a.Type.Name == CodeGenAttributes.CodeGenSerializationAttributeName));
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}

return serializationAttributes;
}

private IEnumerable<ModelProvider> EnumerateModelAndBaseModelProviders()
{
// Custom code can create base-model cycles; stop at the first repeated provider.
var visited = new HashSet<ModelProvider>();
var model = _model;
while (model != null && visited.Add(model))
{
yield return model;
model = model.BaseModelProvider;
}
}

private IEnumerable<ModelProvider> EnumerateBaseModelProviders()
{
// Custom code can create base-model cycles; include the current model in the visited set so
// a cycle back to it is not yielded as one of its own bases.
var visited = new HashSet<ModelProvider> { _model };
var model = _model.BaseModelProvider;
while (model != null && visited.Add(model))
{
yield return model;
model = model.BaseModelProvider;
}
}

private static bool TypeRequiresNullCheckInSerialization(CSharpType type)
{
if (type.IsCollection)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ protected override FormattableString BuildDescription()
private readonly Type _additionalPropsUnknownType = typeof(BinaryData);
private Lazy<bool> _useObjectAdditionalProperties;
private FieldProvider? _rawDataField;
private bool _buildingRawDataField;
private List<FieldProvider>? _additionalPropertyFields;
private List<PropertyProvider>? _additionalPropertyProperties;
private ModelProvider? _baseModelProvider;
Expand Down Expand Up @@ -178,7 +179,33 @@ public override void Reset()
_fullConstructor = null;
}

protected FieldProvider? RawDataField => _rawDataField ??= BuildRawDataField();
protected FieldProvider? RawDataField
{
get
{
if (_rawDataField is not null)
{
return _rawDataField;
}

if (_buildingRawDataField)
{
// BuildRawDataField walks base models and can re-enter this property when custom
// base models form a cycle.
return null;
}

_buildingRawDataField = true;
try
{
return _rawDataField = BuildRawDataField();
}
finally
{
_buildingRawDataField = false;
}
}
}
protected virtual bool ShouldSkipDerivedModelProperties => false;
/// <summary>
/// Gets whether derived models should skip overriding serialization methods from this base model.
Expand Down Expand Up @@ -648,8 +675,11 @@ private IEnumerable<InputModelType> EnumerateBaseModels()

private IEnumerable<ModelProvider> EnumerateBaseModelProviders()
{
// Custom code can create base-model cycles; include this model in the visited set so a cycle
// back to it is not yielded as one of its own bases.
HashSet<ModelProvider> visited = [this];
var model = BaseModelProvider;
while (model != null)
while (model != null && visited.Add(model))
{
yield return model;
model = model.BaseModelProvider;
Expand Down Expand Up @@ -859,9 +889,8 @@ private bool ParametersMatch(IReadOnlyList<ParameterProvider> params1, IReadOnly
private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitialization(bool includeAllHierarchyDiscriminator = false)
{
var properties = new Stack<List<PropertyProvider>>();
var modelProvider = BaseModelProvider;
bool isDirectBase = true;
while (modelProvider != null)
foreach (var modelProvider in EnumerateBaseModelProviders())
{
properties.Push([]);
foreach (var property in modelProvider.CanonicalView.Properties)
Expand All @@ -880,7 +909,6 @@ private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitiali
}
}

modelProvider = modelProvider.BaseModelProvider;
isDirectBase = false;
}

Expand All @@ -891,15 +919,13 @@ private IEnumerable<PropertyProvider> GetAllBasePropertiesForConstructorInitiali
private IEnumerable<FieldProvider> GetAllBaseFieldsForConstructorInitialization()
{
var fields = new Stack<List<FieldProvider>>();
var modelProvider = BaseModelProvider;
while (modelProvider != null)
foreach (var modelProvider in EnumerateBaseModelProviders())
{
fields.Push([]);
foreach (var field in modelProvider.CanonicalView.Fields)
{
fields.Peek().Add(field);
}
modelProvider = modelProvider.BaseModelProvider;
}

return fields.SelectMany(l => l);
Expand All @@ -918,7 +944,7 @@ private IEnumerable<FieldProvider> GetAllBaseFieldsForConstructorInitialization(
baseProperties = GetAllBasePropertiesForConstructorInitialization(includeDiscriminatorParameter);
baseFields = GetAllBaseFieldsForConstructorInitialization();
}
else if (BaseModelProvider?.FullConstructor.Signature != null)
else if (BaseModelProvider is not null && !HasBaseModelProviderCycle())
{
baseParameters.AddRange(BaseModelProvider.FullConstructor.Signature.Parameters);
}
Expand Down Expand Up @@ -1003,6 +1029,25 @@ p.Property is null
return (constructorParameters, constructorInitializer);
}

private bool HasBaseModelProviderCycle()
{
// FullConstructor reads the base constructor signature. If the custom base chain loops back
// to this model, skip that read rather than recursively building this constructor again.
HashSet<ModelProvider> visited = [this];
var modelProvider = BaseModelProvider;
while (modelProvider != null)
{
if (!visited.Add(modelProvider))
{
return true;
}

modelProvider = modelProvider.BaseModelProvider;
}

return false;
}

private ValueExpression? EnsureDiscriminatorValueExpression()
{
if (_inputModel.BaseModel is not null && _inputModel.DiscriminatorValue is not null)
Expand Down Expand Up @@ -1300,14 +1345,12 @@ private static ValueExpression GetConversion(PropertyProvider? property = defaul
}

// check if there is a raw data field on any of the base models, if so, we do not have to have one here.
var baseModelProvider = BaseModelProvider;
while (baseModelProvider != null)
foreach (var baseModelProvider in EnumerateBaseModelProviders())
{
if (baseModelProvider.RawDataField != null)
{
return null;
}
baseModelProvider = baseModelProvider.BaseModelProvider;
}

var modifiers = FieldModifiers.Private;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,49 @@ public BuildBaseTypeOverridingModelProvider(InputModelType inputModel, CSharpTyp
protected override CSharpType? BuildBaseType() => _redirectedBaseType;
}

// Regression: custom code (such as an inheritable system base model) can produce a base
// ModelProvider chain that cycles back on itself. Base-model traversal during constructor,
// field, and raw-data discovery must terminate instead of recursing infinitely.
[Test]
public void BaseModelProviderCycleDoesNotRecurseInfinitely()
{
var inputA = InputFactory.Model(
"ModelA",
usage: InputModelTypeUsage.Input | InputModelTypeUsage.Json | InputModelTypeUsage.Output,
properties: [InputFactory.Property("aProp", InputPrimitiveType.String, isRequired: true)]);
var inputB = InputFactory.Model(
"ModelB",
usage: InputModelTypeUsage.Input | InputModelTypeUsage.Json | InputModelTypeUsage.Output,
properties: [InputFactory.Property("bProp", InputPrimitiveType.String, isRequired: true)]);
MockHelpers.LoadMockGenerator(inputModelTypes: [inputA, inputB]);

var modelA = new CyclicBaseModelProvider(inputA);
var modelB = new CyclicBaseModelProvider(inputB);

// Wire the base-model providers into a cycle: A -> B -> A.
modelA.CyclicBase = modelB;
modelB.CyclicBase = modelA;

Assert.AreSame(modelB, modelA.BaseModelProvider);
Assert.AreSame(modelA, modelB.BaseModelProvider);

// Each of these walks the base-model chain and previously stack-overflowed on a cycle.
Assert.DoesNotThrow(() => _ = modelA.FullConstructor);
Assert.DoesNotThrow(() => _ = modelA.Constructors);
Assert.DoesNotThrow(() => _ = modelA.Fields);
}

private sealed class CyclicBaseModelProvider : ModelProvider
{
public CyclicBaseModelProvider(InputModelType inputModel) : base(inputModel)
{
}

public ModelProvider? CyclicBase { get; set; }

protected override ModelProvider? BuildBaseModelProvider() => CyclicBase;
}

[Test]
public void BuildModelAsStruct()
{
Expand Down
Loading