Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,14 @@ private IReadOnlyList<ModelProvider> BuildDerivedModels()
return existingProvider;
}

// Try to find the type in the customization compilation (excluding referenced assemblies)
// Try to find the type in the customization compilation. Referenced assemblies are
// included so custom bases from framework or external packages are represented by
// normal symbol-backed providers.
var baseTypeProvider = CodeModelGenerator.Instance.SourceInputModel.FindForTypeInCustomization(
baseType.Namespace,
baseType.Name,
baseType.DeclaringType?.Name);
baseType.DeclaringType?.Name,
includeReferencedAssemblies: true);

if (baseTypeProvider != null)
{
Expand All @@ -155,8 +158,8 @@ private IReadOnlyList<ModelProvider> BuildDerivedModels()
return baseTypeProvider;
}

// If we couldn't find the type symbol (e.g., type is from a referenced assembly),
// create a SystemObjectTypeProvider that represents the external type
// If we couldn't find the type symbol, create a SystemObjectTypeProvider that
// represents the external type without member metadata.
var systemObjectTypeProvider = new SystemObjectTypeProvider(baseType);
// Cache it in CSharpTypeMap for future lookups
CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap[baseType] = systemObjectTypeProvider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ internal sealed class NamedTypeSymbolProvider : TypeProvider
{
private INamedTypeSymbol _namedTypeSymbol;
private readonly Compilation _compilation;
private TypeProvider? _baseTypeProvider;

public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol, Compilation compilation)
{
Expand All @@ -41,21 +42,36 @@ public NamedTypeSymbolProvider(INamedTypeSymbol namedTypeSymbol, Compilation com
protected override IReadOnlyList<AttributeStatement> BuildAttributes()
=> [.._namedTypeSymbol.GetAttributes().Select(a => new AttributeStatement(a))];

internal override TypeProvider? BaseTypeProvider => _baseTypeProvider ??= BuildBaseTypeProvider();

protected override CSharpType? BuildBaseType()
{
if (_namedTypeSymbol.BaseType == null
|| _namedTypeSymbol.BaseType.SpecialType == SpecialType.System_Object
|| _namedTypeSymbol.BaseType.SpecialType == SpecialType.System_ValueType
|| _namedTypeSymbol.BaseType.SpecialType == SpecialType.System_Array
|| _namedTypeSymbol.BaseType.SpecialType == SpecialType.System_Enum
|| TypeSymbolExtensions.ContainsTypeAsArgument(_namedTypeSymbol.BaseType, _namedTypeSymbol))
if (ShouldSkipBaseType(_namedTypeSymbol.BaseType))
{
return null;
}

return _namedTypeSymbol.BaseType.GetCSharpType();
return _namedTypeSymbol.BaseType!.GetCSharpType();
Comment thread
JoshLove-msft marked this conversation as resolved.
}

private TypeProvider? BuildBaseTypeProvider()
{
if (ShouldSkipBaseType(_namedTypeSymbol.BaseType))
{
return null;
}

return new NamedTypeSymbolProvider(_namedTypeSymbol.BaseType!, _compilation);
}

private bool ShouldSkipBaseType(INamedTypeSymbol? baseType)
=> baseType == null
|| baseType.SpecialType == SpecialType.System_Object
|| baseType.SpecialType == SpecialType.System_ValueType
|| baseType.SpecialType == SpecialType.System_Array
|| baseType.SpecialType == SpecialType.System_Enum
|| TypeSymbolExtensions.ContainsTypeAsArgument(baseType, _namedTypeSymbol);

protected override TypeSignatureModifiers BuildDeclarationModifiers()
{
var declaredModifiers = GetAccessModifiers(_namedTypeSymbol.DeclaredAccessibility);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,23 @@ private IReadOnlyList<PropertyProvider> BuildAllCustomProperties()
var allCustomProperties = CustomCodeView?.Properties != null
? new List<PropertyProvider>(CustomCodeView.Properties)
: [];
var baseTypeCustomCodeView = BaseTypeProvider?.CustomCodeView;
var baseTypeProvider = BaseTypeProvider;
var includeBaseProviderMembers = CustomCodeView?.BaseType != null;
var visited = new HashSet<TypeProvider>();

// add all custom properties from base types
while (baseTypeCustomCodeView != null)
while (baseTypeProvider != null && visited.Add(baseTypeProvider))
{
allCustomProperties.AddRange(baseTypeCustomCodeView.Properties);
baseTypeCustomCodeView = baseTypeCustomCodeView.BaseTypeProvider?.CustomCodeView;
if (includeBaseProviderMembers)
{
allCustomProperties.AddRange(baseTypeProvider.Properties);
}

if (baseTypeProvider.CustomCodeView is { } customCodeView)
{
allCustomProperties.AddRange(customCodeView.Properties);
}
baseTypeProvider = baseTypeProvider.BaseTypeProvider;
}

return allCustomProperties;
Expand All @@ -81,13 +91,23 @@ private IReadOnlyList<FieldProvider> BuildAllCustomFields()
var allCustomFields = CustomCodeView?.Fields != null
? new List<FieldProvider>(CustomCodeView.Fields)
: [];
var baseTypeCustomCodeView = BaseTypeProvider?.CustomCodeView;
var baseTypeProvider = BaseTypeProvider;
var includeBaseProviderMembers = CustomCodeView?.BaseType != null;
var visited = new HashSet<TypeProvider>();

// add all custom fields from base types
while (baseTypeCustomCodeView != null)
while (baseTypeProvider != null && visited.Add(baseTypeProvider))
{
allCustomFields.AddRange(baseTypeCustomCodeView.Fields);
baseTypeCustomCodeView = baseTypeCustomCodeView.BaseTypeProvider?.CustomCodeView;
if (includeBaseProviderMembers)
{
allCustomFields.AddRange(baseTypeProvider.Fields);
}

if (baseTypeProvider.CustomCodeView is { } customCodeView)
{
allCustomFields.AddRange(customCodeView.Fields);
}
baseTypeProvider = baseTypeProvider.BaseTypeProvider;
}

return allCustomFields;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ private IReadOnlyDictionary<string, INamedTypeSymbol> PopulateNameMap()
return nameMap;
}

public TypeProvider? FindForTypeInCustomization(string ns, string name, string? declaringTypeName = null)
public TypeProvider? FindForTypeInCustomization(string ns, string name, string? declaringTypeName = null, bool includeReferencedAssemblies = false)
{
return FindTypeInCustomization(Customization, ns, name, false, declaringTypeName);
return FindTypeInCustomization(Customization, ns, name, includeReferencedAssemblies, declaringTypeName);
}

public TypeProvider? FindForTypeInLastContract(string ns, string name, string? declaringTypeName = null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1765,10 +1765,100 @@ public async Task CanCustomizeBaseModelToSystemType()
// The BaseModelProvider should be null since the base is not a generated model
Assert.IsNull(modelProvider.BaseModelProvider);

// System types from referenced assemblies are NOT found by FindForTypeInCustomization
// (which only searches the customization assembly, not references), so they use SystemObjectTypeProvider
Assert.IsInstanceOf<SystemObjectTypeProvider>(modelProvider.BaseTypeProvider,
"System.Exception is from a referenced assembly and should use SystemObjectTypeProvider");
// System types from referenced assemblies are found in the customization compilation
// so inherited members can be represented by normal property providers.
Assert.IsInstanceOf<NamedTypeSymbolProvider>(modelProvider.BaseTypeProvider,
"System.Exception is from a referenced assembly and should use NamedTypeSymbolProvider");
}

[Test]
public async Task CanCustomizeSpecBaseModelToSystemType()
{
// This verifies that a custom partial base type wins even when the input model
// has a TypeSpec base model. Otherwise the generated partial would keep the
// TypeSpec base and conflict with the custom partial declaration.
var specBaseModel = InputFactory.Model(
"specBaseModel",
properties: [InputFactory.Property("specBaseProp", InputPrimitiveType.String)],
usage: InputModelTypeUsage.Json);
var childModel = InputFactory.Model(
"mockInputModel",
properties: [
InputFactory.Property("message", InputPrimitiveType.String),
InputFactory.Property("childProp", InputPrimitiveType.String),
],
baseModel: specBaseModel,
usage: InputModelTypeUsage.Json);

var mockGenerator = await MockHelpers.LoadMockGeneratorAsync(
inputModelTypes: [childModel, specBaseModel],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var modelProvider = mockGenerator.Object.OutputLibrary.TypeProviders.Single(t => t.Name == "MockInputModel") as ModelProvider;

Assert.IsNotNull(modelProvider);
Assert.IsNotNull(modelProvider!.BaseType);
Assert.AreEqual("Exception", modelProvider.BaseType!.Name);
Assert.AreEqual("System", modelProvider.BaseType!.Namespace);
Assert.IsNull(modelProvider.BaseModelProvider, "The TypeSpec base model should not be used when custom code declares a system base type.");
Assert.IsInstanceOf<NamedTypeSymbolProvider>(modelProvider.BaseTypeProvider);
Assert.That(modelProvider.Properties.Select(p => p.Name), Does.Not.Contain("Message"));
Assert.That(modelProvider.Properties.Select(p => p.Name), Does.Contain("ChildProp"));

var modelContent = new TypeProviderWriter(modelProvider).Write().Content;
Assert.That(modelContent, Does.Contain("public partial class MockInputModel : global::System.Exception"));
Assert.That(modelContent, Does.Not.Contain("SpecBaseModel"));
Assert.That(modelContent, Does.Not.Contain("public string Message"));
}

[Test]
public async Task CanCustomizeSpecBaseModelToSystemObjectModelProvider()
{
// This verifies the generator-specific system model path used by management-plane
// generators: a custom base type can resolve to a SystemObjectModelProvider in
// CSharpTypeMap, and generated members inherited from that mapped provider are filtered.
var specBaseModel = InputFactory.Model(
"specBaseModel",
properties: [InputFactory.Property("specBaseProp", InputPrimitiveType.String)],
usage: InputModelTypeUsage.Json);
var childModel = InputFactory.Model(
"mockInputModel",
properties: [
InputFactory.Property("id", InputPrimitiveType.String),
InputFactory.Property("name", InputPrimitiveType.String),
InputFactory.Property("childProp", InputPrimitiveType.String),
],
baseModel: specBaseModel,
usage: InputModelTypeUsage.Json);
var systemInputModel = InputFactory.Model(
"ResourceData",
properties: [
InputFactory.Property("id", InputPrimitiveType.String),
InputFactory.Property("name", InputPrimitiveType.String),
],
usage: InputModelTypeUsage.Json);

await MockHelpers.LoadMockGeneratorAsync(
inputModelTypes: [childModel, specBaseModel, systemInputModel],
compilation: async () => await Helpers.GetCompilationFromDirectoryAsync());

var customBaseType = CreateSystemCSharpType("ResourceData", "TestFramework");
CodeModelGenerator.Instance.TypeFactory.CSharpTypeMap[customBaseType] = new SystemObjectModelProvider(customBaseType, systemInputModel);

var modelProvider = new ModelProvider(childModel);

Assert.IsNotNull(modelProvider.BaseType);
Assert.AreEqual("ResourceData", modelProvider.BaseType!.Name);
Assert.AreEqual("TestFramework", modelProvider.BaseType!.Namespace);
Assert.IsInstanceOf<SystemObjectModelProvider>(modelProvider.BaseTypeProvider);
Assert.That(modelProvider.Properties.Select(p => p.Name), Does.Not.Contain("Id"));
Assert.That(modelProvider.Properties.Select(p => p.Name), Does.Not.Contain("Name"));
Assert.That(modelProvider.Properties.Select(p => p.Name), Does.Contain("ChildProp"));

var modelContent = new TypeProviderWriter(modelProvider).Write().Content;
Assert.That(modelContent, Does.Contain("public partial class MockInputModel : global::TestFramework.ResourceData"));
Assert.That(modelContent, Does.Not.Contain("public string Id"));
Assert.That(modelContent, Does.Not.Contain("public string Name"));
}

[Test]
Expand Down Expand Up @@ -1847,6 +1937,10 @@ await MockHelpers.LoadMockGeneratorAsync(

private const string AttributeNamespace = TestCustomCodeAttributeDefinition.AttributeNamespace;

private static CSharpType CreateSystemCSharpType(string name, string ns)
=> new(name, ns, isValueType: false, isNullable: false, declaringType: null,
args: Array.Empty<CSharpType>(), isPublic: true, isStruct: false);

private class TestNameVisitor : NameVisitor
{
public TypeProvider? InvokeVisit(TypeProvider type)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#nullable disable

namespace TestFramework
{
public class ResourceData
{
}
}

namespace Sample.Models
{
public partial class MockInputModel : TestFramework.ResourceData
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#nullable disable

using System;

namespace Sample.Models
{
public partial class MockInputModel : Exception
{
}
}
Loading