Skip to content

Commit b8597b2

Browse files
committed
wrapped functions default export support
1 parent 0b611c7 commit b8597b2

2 files changed

Lines changed: 173 additions & 0 deletions

File tree

codeflash/languages/treesitter_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ class ExportInfo:
9494
reexport_source: str | None # Module path for re-exports
9595
start_line: int
9696
end_line: int
97+
# Functions passed as arguments to wrapper calls in default exports
98+
# e.g., export default curry(traverseEntity) -> ["traverseEntity"]
99+
wrapped_default_args: list[str] | None = None
97100

98101

99102
@dataclass
@@ -707,6 +710,7 @@ def _extract_export_info(self, node: Node, source_bytes: bytes) -> ExportInfo |
707710
default_export: str | None = None
708711
is_reexport = False
709712
reexport_source: str | None = None
713+
wrapped_default_args: list[str] | None = None
710714

711715
# Check for re-export source (export { x } from './other')
712716
source_node = node.child_by_field_name("source")
@@ -726,6 +730,12 @@ def _extract_export_info(self, node: Node, source_bytes: bytes) -> ExportInfo |
726730
default_export = self.get_node_text(sibling, source_bytes)
727731
elif sibling.type in ("arrow_function", "function_expression", "object", "array"):
728732
default_export = "default"
733+
elif sibling.type == "call_expression":
734+
# Handle wrapped exports: export default curry(traverseEntity)
735+
# The default export is the result of the call, but we track
736+
# the wrapped function names for export checking
737+
default_export = "default"
738+
wrapped_default_args = self._extract_call_expression_identifiers(sibling, source_bytes)
729739
break
730740

731741
# Handle named exports: export { a, b as c }
@@ -773,8 +783,37 @@ def _extract_export_info(self, node: Node, source_bytes: bytes) -> ExportInfo |
773783
reexport_source=reexport_source,
774784
start_line=node.start_point[0] + 1,
775785
end_line=node.end_point[0] + 1,
786+
wrapped_default_args=wrapped_default_args,
776787
)
777788

789+
def _extract_call_expression_identifiers(self, node: Node, source_bytes: bytes) -> list[str]:
790+
"""Extract identifier names from arguments of a call expression.
791+
792+
For patterns like curry(traverseEntity) or compose(fn1, fn2), this extracts
793+
the function names passed as arguments: ["traverseEntity"] or ["fn1", "fn2"].
794+
795+
Args:
796+
node: A call_expression node.
797+
source_bytes: The source code as bytes.
798+
799+
Returns:
800+
List of identifier names found in the call arguments.
801+
802+
"""
803+
identifiers: list[str] = []
804+
805+
# Get the arguments node
806+
args_node = node.child_by_field_name("arguments")
807+
if args_node:
808+
for child in args_node.children:
809+
if child.type == "identifier":
810+
identifiers.append(self.get_node_text(child, source_bytes))
811+
# Also handle nested call expressions: compose(curry(fn))
812+
elif child.type == "call_expression":
813+
identifiers.extend(self._extract_call_expression_identifiers(child, source_bytes))
814+
815+
return identifiers
816+
778817
def _extract_commonjs_export(self, node: Node, source_bytes: bytes) -> ExportInfo | None:
779818
"""Extract export information from CommonJS module.exports or exports.* patterns.
780819
@@ -876,6 +915,7 @@ def is_function_exported(
876915
"""Check if a function is exported and get its export name.
877916
878917
For class methods, also checks if the containing class is exported.
918+
Also handles wrapped exports like: export default curry(traverseEntity)
879919
880920
Args:
881921
source: The source code to analyze.
@@ -901,6 +941,11 @@ def is_function_exported(
901941
if name == function_name:
902942
return (True, alias if alias else name)
903943

944+
# Check wrapped default exports: export default curry(traverseEntity)
945+
# The function is exported via wrapper, so it's accessible as "default"
946+
if export.wrapped_default_args and function_name in export.wrapped_default_args:
947+
return (True, "default")
948+
904949
# For class methods, check if the containing class is exported
905950
if class_name:
906951
for export in exports:

tests/test_languages/test_treesitter_utils.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,3 +693,131 @@ def test_non_exported_const_not_exported(self, ts_analyzer):
693693
is_public_exported, name = ts_analyzer.is_function_exported(code, "publicFunc")
694694
assert is_public_exported is True
695695
assert name == "publicFunc"
696+
697+
698+
class TestWrappedDefaultExports:
699+
"""Tests for wrapped default export pattern - Issue #9.
700+
701+
Handles patterns like:
702+
- export default curry(traverseEntity)
703+
- export default compose(fn1, fn2)
704+
- export default wrapper(myFunc)
705+
706+
These must be correctly recognized so the wrapped function is exportable.
707+
"""
708+
709+
@pytest.fixture
710+
def ts_analyzer(self):
711+
"""Create a TypeScript analyzer."""
712+
return TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
713+
714+
def test_curry_wrapped_export(self, ts_analyzer):
715+
"""Test export default curry(fn) pattern."""
716+
code = """import { curry } from 'lodash/fp';
717+
718+
const traverseEntity = async (visitor, options, entity) => {
719+
return entity;
720+
};
721+
722+
export default curry(traverseEntity);"""
723+
724+
# Check exports parsing
725+
exports = ts_analyzer.find_exports(code)
726+
assert len(exports) == 1
727+
assert exports[0].default_export == "default"
728+
assert exports[0].wrapped_default_args == ["traverseEntity"]
729+
730+
# Check is_function_exported
731+
is_exported, export_name = ts_analyzer.is_function_exported(code, "traverseEntity")
732+
assert is_exported is True
733+
assert export_name == "default"
734+
735+
def test_compose_wrapped_export(self, ts_analyzer):
736+
"""Test export default compose(fn1, fn2) pattern with multiple args."""
737+
code = """import { compose } from 'lodash/fp';
738+
739+
function validateInput(data) { return data; }
740+
function processData(data) { return data; }
741+
742+
export default compose(validateInput, processData);"""
743+
744+
exports = ts_analyzer.find_exports(code)
745+
assert len(exports) == 1
746+
assert exports[0].wrapped_default_args == ["validateInput", "processData"]
747+
748+
# Both functions should be recognized as exported
749+
is_exported1, _ = ts_analyzer.is_function_exported(code, "validateInput")
750+
is_exported2, _ = ts_analyzer.is_function_exported(code, "processData")
751+
assert is_exported1 is True
752+
assert is_exported2 is True
753+
754+
def test_nested_wrapper_export(self, ts_analyzer):
755+
"""Test nested wrapper: export default compose(curry(fn))."""
756+
code = """export default compose(curry(myFunc));"""
757+
758+
exports = ts_analyzer.find_exports(code)
759+
assert len(exports) == 1
760+
assert "myFunc" in exports[0].wrapped_default_args
761+
762+
is_exported, _ = ts_analyzer.is_function_exported(code, "myFunc")
763+
assert is_exported is True
764+
765+
def test_generic_wrapper_export(self, ts_analyzer):
766+
"""Test generic wrapper function."""
767+
code = """const myFunction = (x: number) => x * 2;
768+
769+
export default someWrapper(myFunction);"""
770+
771+
is_exported, export_name = ts_analyzer.is_function_exported(code, "myFunction")
772+
assert is_exported is True
773+
assert export_name == "default"
774+
775+
def test_non_wrapped_function_not_exported(self, ts_analyzer):
776+
"""Test that functions not in the wrapper call are not exported."""
777+
code = """const helper = (x: number) => x + 1;
778+
const main = (x: number) => helper(x) * 2;
779+
780+
export default curry(main);"""
781+
782+
# main is wrapped, so it's exported
783+
is_main_exported, _ = ts_analyzer.is_function_exported(code, "main")
784+
assert is_main_exported is True
785+
786+
# helper is NOT in the wrapper call, so not exported
787+
is_helper_exported, _ = ts_analyzer.is_function_exported(code, "helper")
788+
assert is_helper_exported is False
789+
790+
def test_direct_default_export_still_works(self, ts_analyzer):
791+
"""Test that direct default exports still work."""
792+
code = """function myFunc() { return 1; }
793+
export default myFunc;"""
794+
795+
is_exported, export_name = ts_analyzer.is_function_exported(code, "myFunc")
796+
assert is_exported is True
797+
assert export_name == "default"
798+
799+
def test_strapi_traverse_entity_pattern(self, ts_analyzer):
800+
"""Test the exact strapi pattern that was failing."""
801+
code = """import { curry } from 'lodash/fp';
802+
803+
const traverseEntity = async (visitor: Visitor, options: TraverseOptions, entity: Data) => {
804+
const { path = { raw: null }, schema, getModel } = options;
805+
// ... implementation
806+
return copy;
807+
};
808+
809+
const createVisitorUtils = ({ data }: { data: Data }) => ({
810+
remove(key: string) { delete data[key]; },
811+
set(key: string, value: Data) { data[key] = value; },
812+
});
813+
814+
export default curry(traverseEntity);"""
815+
816+
# traverseEntity should be recognized as exported
817+
is_exported, export_name = ts_analyzer.is_function_exported(code, "traverseEntity")
818+
assert is_exported is True
819+
assert export_name == "default"
820+
821+
# createVisitorUtils is NOT wrapped, so not exported via default
822+
is_utils_exported, _ = ts_analyzer.is_function_exported(code, "createVisitorUtils")
823+
assert is_utils_exported is False

0 commit comments

Comments
 (0)