Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ data class Requirement(
val isList: Boolean,
val isProperty: Boolean,
val propertyKey: String?,
val qualifier: QualifierValue?
val qualifier: QualifierValue?,
val isProvided: Boolean = false
) {
/**
* Whether this requirement must be validated (must have a matching provider).
Expand Down Expand Up @@ -165,7 +166,7 @@ class BindingRegistry {

// Skip @Provided types and framework-provided types (always available at runtime)
val reqFqName = req.typeKey.fqName?.asString() ?: req.typeKey.classId?.asFqNameString()
if (reqFqName != null && ProvidedTypeRegistry.isProvided(reqFqName)) {
if (req.isProvided || (reqFqName != null && ProvidedTypeRegistry.isProvided(reqFqName))) {
KoinPluginLogger.debug { " skip '${req.paramName}': ${req.typeKey.render()} (@Provided)" }
continue
}
Expand Down Expand Up @@ -400,7 +401,7 @@ class BindingRegistry {

// Skip @Provided types and framework-provided types (same as real validation path)
val reqFqName = req.typeKey.fqName?.asString() ?: req.typeKey.classId?.asFqNameString()
if (reqFqName != null && ProvidedTypeRegistry.isProvided(reqFqName)) continue
if (req.isProvided || (reqFqName != null && ProvidedTypeRegistry.isProvided(reqFqName))) continue
if (reqFqName != null && isWhitelistedType(reqFqName)) continue

val found = findProviderData(req, provided, consumerScopeFqName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.name.CallableId
import org.jetbrains.kotlin.name.Name
import org.koin.compiler.plugin.KoinPluginConstants
import org.koin.compiler.plugin.KoinAnnotationFqNames
import org.koin.compiler.plugin.KoinPluginLogger
import org.koin.compiler.plugin.ProvidedTypeRegistry
import org.koin.compiler.plugin.fir.KoinModuleFirGenerator
Expand Down Expand Up @@ -82,7 +83,7 @@ class CallSiteValidator(private val context: IrPluginContext) {

for (callSite in callSites) {
// Skip @Provided types
if (ProvidedTypeRegistry.isProvided(callSite.targetFqName)) {
if (ProvidedTypeRegistry.isProvided(callSite.targetFqName) || callSite.targetClass.hasAnnotation(KoinAnnotationFqNames.PROVIDED)) {
KoinPluginLogger.debug { "A4: Skip ${callSite.targetFqName} (@Provided)" }
continue
}
Expand Down Expand Up @@ -299,7 +300,7 @@ class CallSiteValidator(private val context: IrPluginContext) {
val targetFqName = targetClass.fqNameWhenAvailable?.asString() ?: continue

// Skip @Provided types
if (ProvidedTypeRegistry.isProvided(targetFqName)) {
if (ProvidedTypeRegistry.isProvided(targetFqName) || targetClass.hasAnnotation(KoinAnnotationFqNames.PROVIDED)) {
KoinPluginLogger.debug { "A4-deferred: Skip $targetFqName (@Provided)" }
continue
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ class DefinitionCallBuilder(
builder: DeclarationIrBuilder,
parentFunction: IrFunction
): IrExpression {
return lambdaBuilder.create(returnTypeClass, builder, parentFunction) { irBuilder, scopeParam, paramsParam ->
return lambdaBuilder.create(returnTypeClass.defaultType, builder, parentFunction) { irBuilder, scopeParam, paramsParam ->
irBuilder.irCallConstructor(constructor.symbol, emptyList()).apply {
constructor.valueParameters.forEachIndexed { index, param ->
val scopeGet = irBuilder.irGet(scopeParam)
Expand Down Expand Up @@ -571,7 +571,7 @@ class DefinitionCallBuilder(
return builder.irNull()
}

return lambdaBuilder.create(returnTypeClass, builder, parentFunction) { irBuilder, scopeParam, paramsParam ->
return lambdaBuilder.create(returnTypeClass.defaultType, builder, parentFunction) { irBuilder, scopeParam, paramsParam ->
irBuilder.irCall(targetFunction.symbol).apply {
dispatchReceiver = irBuilder.irGet(moduleInstanceReceiver)

Expand All @@ -596,7 +596,7 @@ class DefinitionCallBuilder(
builder: DeclarationIrBuilder,
parentFunction: IrFunction
): IrExpression {
return lambdaBuilder.create(returnTypeClass, builder, parentFunction) { irBuilder, scopeParam, paramsParam ->
return lambdaBuilder.create(returnTypeClass.defaultType, builder, parentFunction) { irBuilder, scopeParam, paramsParam ->
irBuilder.irCall(targetFunction.symbol).apply {
targetFunction.valueParameters.forEachIndexed { index, param ->
val scopeGet = irBuilder.irGet(scopeParam)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class KoinArgumentGenerator(

private val propertyAnnotationFqName = KoinAnnotationFqNames.PROPERTY
private val lazyModeClass by lazy { context.referenceClass(ClassId.topLevel(FqName("kotlin.LazyThreadSafetyMode"))) }
private val listClass by lazy { context.referenceClass(ClassId.topLevel(FqName("kotlin.collections.List")))?.owner }
private val lazyClass by lazy { context.referenceClass(ClassId.topLevel(FqName("kotlin.Lazy")))?.owner }

override fun generateForParameter(
param: IrValueParameter,
Expand Down Expand Up @@ -296,28 +298,67 @@ class KoinArgumentGenerator(
builder: DeclarationIrBuilder
): IrExpression {
val scopeClass = (scopeReceiver.type.classifierOrNull?.owner as? IrClass)
if (scopeClass == null) {
KoinPluginLogger.debug { "Could not resolve scope class for getAll<${elementType.classFqName}>() call" }
return builder.irNull()
}
val listType = listClass?.typeWith(elementType)
if (scopeClass != null) {
// Prefer the member function when available, but fall back to the top-level extension
// so List<T> injection keeps working across Koin API shapes.
val getAllFunction = scopeClass.declarations
.filterIsInstance<IrSimpleFunction>()
.firstOrNull { function ->
function.name.asString() == "getAll" &&
function.typeParameters.size == 1 &&
function.valueParameters.isEmpty()
}
?: context.referenceFunctions(
CallableId(FqName("org.koin.core.scope"), Name.identifier("getAll"))
).map { it.owner }
.filterIsInstance<IrSimpleFunction>()
.firstOrNull { function ->
function.name.asString() == "getAll" &&
function.typeParameters.size == 1 &&
function.valueParameters.isEmpty() &&
(
function.dispatchReceiverParameter?.type?.classifierOrNull?.owner == scopeClass ||
function.extensionReceiverParameter?.type?.classifierOrNull?.owner == scopeClass
)
}

// Find getAll function in Scope
val getAllFunction = scopeClass.declarations
.filterIsInstance<IrSimpleFunction>()
.firstOrNull { function ->
function.name.asString() == "getAll" &&
function.typeParameters.size == 1
if (getAllFunction != null) {
return builder.irCall(getAllFunction.symbol).apply {
// Explicitly actualize return type to avoid leaking unbound function type parameter.
if (listType != null) type = listType
if (getAllFunction.dispatchReceiverParameter != null) {
dispatchReceiver = scopeReceiver
} else if (getAllFunction.extensionReceiverParameter != null) {
extensionReceiver = scopeReceiver
}
putTypeArgument(0, elementType)
}
}

if (getAllFunction != null) {
return builder.irCall(getAllFunction.symbol).apply {
dispatchReceiver = scopeReceiver
KoinPluginLogger.debug {
"Could not find getAll function on scope class ${scopeClass.name} for element type ${elementType.classFqName}; using emptyList() fallback"
}
} else {
KoinPluginLogger.debug {
"Could not resolve scope class for getAll<${elementType.classFqName}>() call; using emptyList() fallback"
}
}

val emptyListFunction = context.referenceFunctions(
CallableId(FqName("kotlin.collections"), Name.identifier("emptyList"))
).firstOrNull()?.owner

if (emptyListFunction != null) {
return builder.irCall(emptyListFunction.symbol).apply {
if (listType != null) type = listType
putTypeArgument(0, elementType)
}
}

// Fallback to empty list if getAll not found
KoinPluginLogger.debug { "Could not find getAll function on scope class ${scopeClass.name} for element type ${elementType.classFqName}" }
KoinPluginLogger.debug {
"Could not resolve emptyList<${elementType.classFqName}>() fallback for List injection"
}
return builder.irNull()
}

Expand Down Expand Up @@ -346,9 +387,12 @@ class KoinArgumentGenerator(
return builder.irNull()
}

val requestedType = type
return builder.irCall(getFunction.symbol).apply {
// Explicitly actualize return type to avoid leaking unbound function type parameter.
this.type = requestedType
dispatchReceiver = scopeReceiver
putTypeArgument(0, type)
putTypeArgument(0, requestedType)

getFunction.valueParameters.forEachIndexed { index, param ->
val paramTypeName = (param.type.classifierOrNull?.owner as? IrClass)?.name?.asString()
Expand Down Expand Up @@ -386,9 +430,12 @@ class KoinArgumentGenerator(
return builder.irNull()
}

val requestedType = type
return builder.irCall(getOrNullFunction.symbol).apply {
// Explicitly actualize return type to avoid leaking unbound function type parameter.
this.type = requestedType.makeNullable()
dispatchReceiver = scopeReceiver
putTypeArgument(0, type)
putTypeArgument(0, requestedType)

getOrNullFunction.valueParameters.forEachIndexed { index, param ->
val paramTypeName = (param.type.classifierOrNull?.owner as? IrClass)?.name?.asString()
Expand Down Expand Up @@ -430,9 +477,15 @@ class KoinArgumentGenerator(
?.filterIsInstance<IrEnumEntry>()
?.firstOrNull { it.name.asString() == "SYNCHRONIZED" }

val requestedType = type
return builder.irCall(injectFunction.symbol).apply {
// Explicitly actualize return type to avoid leaking unbound function type parameter.
val lazyType = lazyClass?.typeWith(requestedType)
if (lazyType != null) {
this.type = lazyType
}
dispatchReceiver = scopeReceiver
putTypeArgument(0, type)
putTypeArgument(0, requestedType)

injectFunction.valueParameters.forEachIndexed { index, param ->
val paramType = param.type
Expand Down Expand Up @@ -480,9 +533,11 @@ class KoinArgumentGenerator(
return builder.irNull()
}

val requestedType = type
return builder.irCall(getFunction.symbol).apply {
this.type = requestedType
dispatchReceiver = parametersHolderReceiver
putTypeArgument(0, type)
putTypeArgument(0, requestedType)
}
}

Expand All @@ -509,9 +564,11 @@ class KoinArgumentGenerator(
return builder.irNull()
}

val requestedType = type
return builder.irCall(getOrNullFunction.symbol).apply {
this.type = requestedType.makeNullable()
dispatchReceiver = parametersHolderReceiver
putTypeArgument(0, type)
putTypeArgument(0, requestedType)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ class KoinDSLTransformer(
): IrExpression {
val typeArg = call.getTypeArgument(0) ?: return call
val targetClass = typeArg.classifierOrNull?.owner as? IrClass ?: return call
val erasedTargetType = erasedTypeForClass(targetClass)
val constructorTypeArguments = extractConstructorTypeArguments(typeArg, targetClass)
val constructor = targetClass.primaryConstructor
if (constructor == null) {
KoinPluginLogger.debug { "$functionName<${targetClass.name}>() skipped - no primary constructor" }
Expand Down Expand Up @@ -335,24 +337,24 @@ class KoinDSLTransformer(
// Build the transformed call
return builder.irCall(targetFunction.symbol).apply {
this.extensionReceiver = extensionReceiver
putTypeArgument(0, targetClass.defaultType)
putTypeArgument(0, typeArg)

// Arg 0: KClass<T>
val kClassClassOwner = kClassClass ?: return call
putValueArgument(0, IrClassReferenceImpl(
UNDEFINED_OFFSET, UNDEFINED_OFFSET,
kClassClassOwner.typeWith(targetClass.defaultType),
kClassClassOwner.typeWith(erasedTargetType),
targetClass.symbol,
targetClass.defaultType
erasedTargetType
))

// Arg 1: Qualifier? (for workers, always use class name as qualifier)
putValueArgument(1, qualifierExtractor.createQualifierCall(effectiveQualifier, builder) ?: builder.irNull())

// Arg 2: Definition lambda { T(get(), get(), ...) }
val parentFunc = currentFunction ?: return call
putValueArgument(2, lambdaBuilder.create(targetClass, builder, parentFunc) { lb, scopeParam, paramsParam ->
lb.irCallConstructor(constructor.symbol, emptyList()).apply {
putValueArgument(2, lambdaBuilder.create(typeArg, builder, parentFunc) { lb, scopeParam, paramsParam ->
lb.irCallConstructor(constructor.symbol, constructorTypeArguments).apply {
constructor.valueParameters.forEachIndexed { index, param ->
val scopeGet = lb.irGet(scopeParam)
val paramsGet = lb.irGet(paramsParam)
Expand Down Expand Up @@ -469,18 +471,19 @@ class KoinDSLTransformer(
KoinPluginLogger.user { "Applying qualifier ${qualifier.debugString()} to $functionName { create(::${returnClass.name}) }" }

val builder = DeclarationIrBuilder(context, call.symbol, call.startOffset, call.endOffset)
val erasedReturnType = erasedTypeForClass(returnClass)

return builder.irCall(targetFunction.symbol).apply {
this.extensionReceiver = receiver
putTypeArgument(0, returnClass.defaultType)
putTypeArgument(0, erasedReturnType)

// Arg 0: KClass<T>
val kClassClassOwner = kClassClass ?: return call
putValueArgument(0, IrClassReferenceImpl(
UNDEFINED_OFFSET, UNDEFINED_OFFSET,
kClassClassOwner.typeWith(returnClass.defaultType),
kClassClassOwner.typeWith(erasedReturnType),
returnClass.symbol,
returnClass.defaultType
erasedReturnType
))

// Arg 1: Qualifier
Expand Down Expand Up @@ -535,6 +538,34 @@ class KoinDSLTransformer(
return expr === targetCall
}

/**
* Runtime KClass registration in Koin is erased, so generic type arguments are replaced
* with Any? when emitting class literals (Foo::class).
*/
private fun erasedTypeForClass(targetClass: IrClass): IrType {
if (targetClass.typeParameters.isEmpty()) return targetClass.defaultType
val erasedArguments = Array(targetClass.typeParameters.size) { context.irBuiltIns.anyNType }
return targetClass.typeWith(*erasedArguments)
}

/**
* Preserve concrete type arguments for constructor calls when the user wrote single<Foo<Bar>>().
*/
private fun extractConstructorTypeArguments(typeArg: IrType, targetClass: IrClass): List<IrType> {
val simpleType = typeArg as? IrSimpleType ?: return emptyList()
if (simpleType.classifierOrNull?.owner != targetClass) return emptyList()
if (simpleType.arguments.size != targetClass.typeParameters.size) return emptyList()

val concreteArguments = simpleType.arguments
.mapNotNull { (it as? IrTypeProjection)?.type }

return if (concreteArguments.size == targetClass.typeParameters.size) {
concreteArguments
} else {
emptyList()
}
}

private fun findTargetFunction(functionName: Name, receiverClassName: String): IrSimpleFunction? {
// Map stub function name to target function name (e.g., single -> buildSingle)
val targetName = targetFunctionNames[functionName] ?: return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
import org.jetbrains.kotlin.ir.expressions.impl.IrFunctionExpressionImpl
import org.jetbrains.kotlin.ir.symbols.impl.IrSimpleFunctionSymbolImpl
import org.jetbrains.kotlin.ir.symbols.impl.IrValueParameterSymbolImpl
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.typeWith
import org.jetbrains.kotlin.ir.util.defaultType
import org.jetbrains.kotlin.name.ClassId
Expand Down Expand Up @@ -75,14 +76,14 @@ class LambdaBuilder(
* - scopeParam: The Scope extension receiver parameter
* - paramsParam: The ParametersHolder value parameter
*
* @param returnTypeClass The return type of the lambda
* @param returnType The return type of the lambda
* @param builder The outer declaration builder
* @param parentFunction The parent function containing this lambda
* @param bodyBuilder Callback to create the body expression
* @return The lambda expression, or irNull() if required classes are not found
*/
fun create(
returnTypeClass: IrClass,
returnType: IrType,
builder: DeclarationIrBuilder,
parentFunction: IrFunction,
bodyBuilder: (
Expand All @@ -108,7 +109,7 @@ class LambdaBuilder(
visibility = DescriptorVisibilities.LOCAL,
isInline = false,
isExpect = false,
returnType = returnTypeClass.defaultType,
returnType = returnType,
modality = Modality.FINAL,
symbol = IrSimpleFunctionSymbolImpl(),
isTailrec = false,
Expand Down Expand Up @@ -172,7 +173,7 @@ class LambdaBuilder(
val lambdaType = func2Class.typeWith(
scopeClassLocal.defaultType,
paramsHolderClass.defaultType,
returnTypeClass.defaultType
returnType
)

return IrFunctionExpressionImpl(
Expand Down
Loading