2222import shutil
2323import tempfile
2424import unittest
25+ import yaml
2526
2627import apache_beam as beam
2728from apache_beam .testing .util import assert_that
2829from apache_beam .testing .util import equal_to
2930from apache_beam .utils import python_callable
3031from apache_beam .yaml import yaml_provider
32+ from apache_beam .yaml .yaml_transform import SafeLineLoader
3133from apache_beam .yaml .yaml_transform import YamlTransform
34+ from apache_beam .yaml .yaml_transform import expand_jinja
3235
3336try :
3437 import jsonschema
@@ -1467,6 +1470,65 @@ def test_must_consume_error_output(self):
14671470 ''' ,
14681471 providers = merged_providers )
14691472
1473+ def test_provider_with_jinja_imports (self ):
1474+ # Create a macro file in the same temp directory as the provider
1475+ macro_path = os .path .join (self .temp_dir , 'my_macros.yaml' )
1476+ with open (macro_path , 'w' ) as f :
1477+ f .write (
1478+ """
1479+ {%- macro power_expr(var, n) -%}
1480+ {{ var }} ** {{ n }}
1481+ {%- endmacro -%}
1482+ """ )
1483+
1484+ # Create a provider that imports and uses the macro
1485+ templated_provider_path = os .path .join (
1486+ self .temp_dir , 'templated_provider.yaml' )
1487+ with open (templated_provider_path , 'w' ) as f :
1488+ f .write (
1489+ """
1490+ - type: yaml
1491+ transforms:
1492+ CustomPower:
1493+ config_schema:
1494+ properties:
1495+ n: {type: integer}
1496+ body: |
1497+ type: MapToFields
1498+ config:
1499+ language: python
1500+ append: true
1501+ fields:
1502+ power: "{% import 'my_macros.yaml' as m %}{{ m.power_expr('element', n) }}"
1503+ """ )
1504+
1505+ loaded_providers = yaml_provider .load_providers (templated_provider_path )
1506+ test_providers = yaml_provider .InlineProvider (TEST_PROVIDERS )
1507+ merged_providers = yaml_provider .merge_providers (
1508+ loaded_providers , [test_providers ])
1509+
1510+ with beam .Pipeline (options = beam .options .pipeline_options .PipelineOptions (
1511+ pickle_library = 'cloudpickle' )) as p :
1512+ results = p | YamlTransform (
1513+ '''
1514+ type: composite
1515+ transforms:
1516+ - type: Create
1517+ config:
1518+ elements: [2, 3]
1519+ - type: CustomPower
1520+ input: Create
1521+ config:
1522+ n: 3
1523+ output: CustomPower
1524+ ''' ,
1525+ providers = merged_providers )
1526+
1527+ assert_that (
1528+ results ,
1529+ equal_to (
1530+ [beam .Row (element = 2 , power = 8 ), beam .Row (element = 3 , power = 27 )]))
1531+
14701532
14711533@beam .transforms .ptransform .annotate_yaml
14721534class LinearTransform (beam .PTransform ):
@@ -1481,6 +1543,62 @@ def expand(self, pcoll):
14811543 return pcoll | beam .Map (lambda x : a * x .element + b )
14821544
14831545
1546+ class TestYamlExpandJinja (unittest .TestCase ):
1547+ def setUp (self ):
1548+ self .temp_dir = tempfile .mkdtemp ()
1549+ # Create a macro file with leading comments (license header)
1550+ self .macro_path = os .path .join (self .temp_dir , 'my_macros.yaml' )
1551+ with open (self .macro_path , 'w' ) as f :
1552+ f .write (
1553+ """# coding=utf-8
1554+ # Licensed to the Apache Software Foundation...
1555+ # Some leading comment line
1556+
1557+ {%- macro add_n(val, n) -%}
1558+ {{ val }} + {{ n }}
1559+ {%- endmacro -%}
1560+ """ )
1561+
1562+ # Create a pipeline template that includes/imports the macro
1563+ self .pipeline_path = os .path .join (self .temp_dir , 'my_pipeline.yaml' )
1564+ with open (self .pipeline_path , 'w' ) as f :
1565+ f .write (
1566+ """# coding=utf-8
1567+ # Licensed to the Apache Software Foundation...
1568+
1569+ {% import 'my_macros.yaml' as macros %}
1570+ type: composite
1571+ transforms:
1572+ - type: Create
1573+ config:
1574+ elements: [1, 2, 3]
1575+ - type: MapToFields
1576+ config:
1577+ language: python
1578+ fields:
1579+ result: {{ macros.add_n('element', 10) }}
1580+ """ )
1581+
1582+ def tearDown (self ):
1583+ shutil .rmtree (self .temp_dir )
1584+
1585+ def test_expand_jinja_with_leading_comments_and_imports (self ):
1586+ # Read the pipeline template
1587+ with open (self .pipeline_path , 'r' ) as f :
1588+ template_content = f .read ()
1589+
1590+ # Expand the jinja using our temp_dir as a search path
1591+ expanded = expand_jinja (template_content , {}, [self .temp_dir ])
1592+
1593+ # Parse the expanded YAML
1594+ parsed = yaml .load (expanded , Loader = SafeLineLoader )
1595+
1596+ # Verify the comment-stripping and import resolution was successful
1597+ self .assertEqual (parsed ['type' ], 'composite' )
1598+ self .assertEqual (
1599+ parsed ['transforms' ][1 ]['config' ]['fields' ]['result' ], 'element + 10' )
1600+
1601+
14841602if __name__ == '__main__' :
14851603 logging .getLogger ().setLevel (logging .INFO )
14861604 unittest .main ()
0 commit comments