Skip to content

Commit 87017d6

Browse files
committed
added generic type definitions, tying to get overloads working
1 parent 12b2880 commit 87017d6

1 file changed

Lines changed: 58 additions & 2 deletions

File tree

stubgen/dotnet_stubs.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -854,11 +854,15 @@ def _should_include_method(self, method, type_info: TypeInfo, docs: Dict[str, Do
854854
if '.' in method_name and not method_name.startswith('op_'):
855855
return False
856856

857+
if method_name.startswith('op_'):
858+
return True
859+
857860
# Since we're only getting public methods now, include all remaining methods
858-
return True
861+
return not method.IsSpecialName
859862

860863
def _extract_methods(self, net_type, type_info: TypeInfo):
861864
"""Extract method information"""
865+
logger.debug(f"Considering method: {method.Name}, IsSpecialName: {method.IsSpecialName}, IsPublic: {method.IsPublic}, IsStatic: {method.IsStatic}")
862866
try:
863867
# Get all methods (public only, including inherited methods for complete API surface)
864868
binding_flags = BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static
@@ -1211,10 +1215,15 @@ def add_type_reference(self, type_str: str, available_types: Dict[str, TypeInfo]
12111215
return
12121216

12131217
# Handle typing imports
1214-
if type_str in {'Any', 'List', 'Dict', 'Optional', 'Union', 'Generic', 'Callable', 'Tuple', 'Set'}:
1218+
if type_str in {'Any', 'List', 'Dict', 'Optional', 'Union', 'Generic', 'Callable', 'Tuple', 'Set', 'TypeVar'}:
12151219
self.typing_imports.add(type_str)
12161220
return
12171221

1222+
# Handle generic type parameters (T, T1, T2, etc.)
1223+
if re.match(r'^T\d*$', type_str):
1224+
self.typing_imports.add('TypeVar')
1225+
return
1226+
12181227
# Handle datetime imports
12191228
if type_str in {'datetime', 'timedelta'}:
12201229
self.datetime_imports.add(type_str)
@@ -1321,6 +1330,30 @@ def generate_import_lines(self) -> List[str]:
13211330
class PythonStubGenerator:
13221331
"""Generates Python stub files from type information"""
13231332

1333+
OPERATOR_TO_MAGIC_METHOD = {
1334+
'op_Addition': '__add__',
1335+
'op_Subtraction': '__sub__',
1336+
'op_Multiply': '__mul__',
1337+
'op_Division': '__truediv__',
1338+
'op_Modulus': '__mod__',
1339+
'op_Equality': '__eq__',
1340+
'op_Inequality': '__ne__',
1341+
'op_LessThan': '__lt__',
1342+
'op_GreaterThan': '__gt__',
1343+
'op_LessThanOrEqual': '__le__',
1344+
'op_GreaterThanOrEqual': '__ge__',
1345+
'op_BitwiseAnd': '__and__',
1346+
'op_BitwiseOr': '__or__',
1347+
'op_ExclusiveOr': '__xor__',
1348+
'op_UnaryNegation': '__neg__',
1349+
'op_UnaryPlus': '__pos__',
1350+
'op_Increment': '__inc__',
1351+
'op_Decrement': '__dec__',
1352+
'op_LogicalNot': '__not__',
1353+
'op_LeftShift': '__lshift__',
1354+
'op_RightShift': '__rshift__',
1355+
}
1356+
13241357
def __init__(self, type_infos: Dict[str, TypeInfo], docs: Dict[str, DocumentationInfo]):
13251358
self.type_infos = type_infos
13261359
self.docs = docs
@@ -1461,6 +1494,12 @@ def _generate_imports_for_namespace(self, namespace: str, types: List[TypeInfo])
14611494

14621495
# Comprehensive typing imports needed for stubs
14631496
typing_imports = {'Any', 'List', 'Dict', 'Optional', 'Union', 'Generic', 'overload'}
1497+
1498+
# Check if any types in this namespace are generic and need TypeVar
1499+
has_generic_types = any(type_info.is_generic and type_info.generic_parameters for type_info in types)
1500+
if has_generic_types:
1501+
typing_imports.add('TypeVar')
1502+
14641503
datetime_imports = {'datetime'}
14651504
system_imports = {} # namespace -> set of types
14661505
cross_namespace_imports = {} # namespace -> set of types
@@ -1570,6 +1609,22 @@ def _generate_module_content(self, namespace: str, types: List[TypeInfo]) -> str
15701609
lines.extend(import_lines)
15711610
lines.append("")
15721611

1612+
# Collect all generic type parameters used in this module
1613+
type_vars_needed = set()
1614+
for type_info in types:
1615+
if type_info.is_generic and type_info.generic_parameters:
1616+
for i, param in enumerate(type_info.generic_parameters):
1617+
if i == 0:
1618+
type_vars_needed.add('T')
1619+
else:
1620+
type_vars_needed.add(f'T{i}')
1621+
1622+
# Generate TypeVar declarations if needed
1623+
if type_vars_needed:
1624+
for type_var in sorted(type_vars_needed):
1625+
lines.append(f"{type_var} = TypeVar('{type_var}')")
1626+
lines.append("")
1627+
15731628
# Generate type stubs
15741629
for type_info in sorted(types, key=lambda t: t.simple_name):
15751630
if type_info.is_class:
@@ -1779,6 +1834,7 @@ def _generate_method_stub(self, type_info: TypeInfo, method_info: MethodInfo, is
17791834
lines.append(" @staticmethod")
17801835

17811836
# Sanitize method name
1837+
method_name = self.OPERATOR_TO_MAGIC_METHOD.get(method_info.name, method_info.name)
17821838
method_name = self._sanitize_identifier(method_info.name)
17831839

17841840
# Build parameter list

0 commit comments

Comments
 (0)