Skip to content

Commit bf60db4

Browse files
committed
expand jinja
1 parent ef703f6 commit bf60db4

4 files changed

Lines changed: 185 additions & 9 deletions

File tree

sdks/python/apache_beam/yaml/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,13 @@ def _build_pipeline_yaml_from_argv(argv):
235235
argv = _preparse_jinja_flags(argv)
236236
known_args, pipeline_args = _parse_arguments(argv)
237237
pipeline_template = _pipeline_spec_from_args(known_args)
238+
239+
search_paths = []
240+
if known_args.yaml_pipeline_file:
241+
search_paths.append(FileSystems.split(known_args.yaml_pipeline_file)[0])
242+
238243
pipeline_yaml = yaml_transform.expand_jinja(
239-
pipeline_template, known_args.jinja_variables or {})
244+
pipeline_template, known_args.jinja_variables or {}, search_paths)
240245
return known_args, pipeline_args, pipeline_template, pipeline_yaml
241246

242247

sdks/python/apache_beam/yaml/yaml_provider.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,10 +469,14 @@ def _with_extra_dependencies(self, dependencies: Iterable[str]):
469469

470470
@ExternalProvider.register_provider_type('yaml')
471471
class YamlProvider(Provider):
472-
def __init__(self, transforms: Mapping[str, Mapping[str, Any]]):
472+
def __init__(
    self,
    transforms: Mapping[str, Mapping[str, Any]],
    provider_base_path: Optional[str] = None):
  """Initializes the provider from a mapping of transform definitions.

  Args:
    transforms: transform definitions keyed by name. Although annotated as
      a ``Mapping``, the value must be an actual ``dict`` instance.
    provider_base_path: optional path recorded on the instance;
      ``create_transform`` uses its directory component as a Jinja search
      path when expanding a templated transform body.

  Raises:
    ValueError: if ``transforms`` is not a ``dict``.
  """
  is_dict = isinstance(transforms, dict)
  if not is_dict:
    raise ValueError('Transform mapping must be a dict.')
  self._provider_base_path = provider_base_path
  self._transforms = transforms
476480

477481
def available(self):
  """Reports this provider as available; it returns True unconditionally."""
  return True
@@ -524,7 +528,10 @@ def create_transform(
524528
else:
525529
body_str = yaml.safe_dump(SafeLineLoader.strip_metadata(body))
526530
# Now re-parse resolved templatization.
527-
body = yaml.load(expand_jinja(body_str, args), Loader=SafeLineLoader)
531+
search_paths = [FileSystems.split(self._provider_base_path)[0]
532+
] if self._provider_base_path else []
533+
body = yaml.load(
534+
expand_jinja(body_str, args, search_paths), Loader=SafeLineLoader)
528535
if (body.get('type') == 'chain' and 'input' not in body and
529536
spec.get('requires_inputs', True)):
530537
body['input'] = 'input'

sdks/python/apache_beam/yaml/yaml_transform.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,19 +1391,65 @@ def validate_transform_references(spec):
13911391
return spec
13921392

13931393

1394+
def strip_leading_comments(source: str) -> str:
  """Drops the leading run of ``#`` comment lines and blank lines.

  Scanning starts at the top of ``source``.  Every line that is empty,
  whitespace-only, or whose first non-whitespace character is ``#`` is
  discarded until the first line that is neither.  From that line onward
  the text is returned untouched, including any later comment or blank
  lines and the original line endings.

  Args:
    source: the text to trim.

  Returns:
    The remainder of ``source`` starting at its first content line, or the
    empty string when ``source`` holds nothing but comments and blank lines.
  """
  lines = source.splitlines(keepends=True)
  for index, line in enumerate(lines):
    content = line.lstrip()
    # The first line with real content ends the header; keep the rest as is.
    if content and not content.startswith('#'):
      return "".join(lines[index:])
  return ""
1407+
1408+
13941409
class _BeamFileIOLoader(jinja2.BaseLoader):
  """Jinja2 template loader that reads templates through ``FileSystems``.

  A template name that carries a filesystem scheme or starts with ``/`` is
  looked up verbatim only.  Any other name is tried verbatim first and is
  then joined onto each entry of ``search_paths``, in order.
  """
  def __init__(self, search_paths=()):
    """Records the directories used to resolve relative template names.

    Args:
      search_paths: iterable of directory paths tried, in order, for names
        that are neither scheme-qualified nor rooted at ``/``.
    """
    # Copy so later mutation of the caller's sequence cannot affect lookups.
    self.search_paths = list(search_paths)

  def get_source(self, environment, path):
    """Loads the first readable candidate for ``path``.

    Args:
      environment: the requesting Jinja2 environment (unused).
      path: the template name as written in the importing template.

    Returns:
      A ``(source, resolved_path, uptodate)`` tuple, where ``source`` has
      had its leading comment/blank lines removed by
      ``strip_leading_comments`` and ``uptodate`` always returns True.

    Raises:
      jinja2.TemplateNotFound: if no candidate exists or could be read.
    """
    import logging

    candidates = [path]
    is_absolute = (
        FileSystems.get_scheme(path) is not None or path.startswith('/'))
    if not is_absolute:
      candidates.extend(
          FileSystems.join(search_path, path)
          for search_path in self.search_paths)

    for candidate in candidates:
      try:
        if not FileSystems.exists(candidate):
          continue
        with FileSystems.open(candidate) as fin:
          source = fin.read().decode()
      except Exception:  # pylint: disable=broad-except
        # Lookup stays best-effort: a candidate that cannot be probed or
        # read must not prevent later candidates from being tried.  The
        # failure is logged so that permission, I/O or decoding problems
        # are not silently reported as a missing template.
        logging.getLogger(__name__).warning(
            'Failed to load Jinja template candidate %r; trying next.',
            candidate,
            exc_info=True)
        continue
      return strip_leading_comments(source), candidate, lambda: True

    raise jinja2.TemplateNotFound(
        path, message='%s (searched: %s)' % (path, ', '.join(candidates)))
13991432

14001433

14011434
def expand_jinja(
    jinja_template: str,
    jinja_variables: Mapping[str, Any],
    search_paths: Iterable[str] = ()) -> str:
  """Renders ``jinja_template`` with ``jinja_variables`` and returns the text.

  Leading comment/blank lines are removed from the template (via
  ``strip_leading_comments``) before it is compiled.  Templates referenced
  from it are resolved by ``_BeamFileIOLoader`` against ``search_paths``,
  followed by two fallbacks appended when not already listed: the directory
  that contains the ``apache_beam`` package, and ``'.'``.

  Args:
    jinja_template: the template text to render.
    jinja_variables: values exposed to the template as keyword arguments.
      ``datetime`` is always supplied in addition to these.
    search_paths: directories consulted, in order, for relative template
      names before the fallbacks.

  Returns:
    The rendered template.
  """
  # Imported here rather than at module scope, as in the original code.
  import apache_beam

  package_dir = os.path.dirname(os.path.abspath(apache_beam.__file__))
  resolved_paths = list(search_paths)
  for fallback in (os.path.dirname(package_dir), '.'):
    if fallback not in resolved_paths:
      resolved_paths.append(fallback)

  environment = jinja2.Environment(
      undefined=jinja2.StrictUndefined,
      loader=_BeamFileIOLoader(resolved_paths))
  template = environment.from_string(strip_leading_comments(jinja_template))
  return template.render(datetime=datetime, **jinja_variables)
14081454

14091455

sdks/python/apache_beam/yaml/yaml_transform_test.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@
2222
import shutil
2323
import tempfile
2424
import unittest
25+
import yaml
2526

2627
import apache_beam as beam
2728
from apache_beam.testing.util import assert_that
2829
from apache_beam.testing.util import equal_to
2930
from apache_beam.utils import python_callable
3031
from apache_beam.yaml import yaml_provider
32+
from apache_beam.yaml.yaml_transform import SafeLineLoader
3133
from apache_beam.yaml.yaml_transform import YamlTransform
34+
from apache_beam.yaml.yaml_transform import expand_jinja
3235

3336
try:
3437
import jsonschema
@@ -1467,6 +1470,65 @@ def test_must_consume_error_output(self):
14671470
''',
14681471
providers=merged_providers)
14691472

1473+
def test_provider_with_jinja_imports(self):
1474+
# Create a macro file in the same temp directory as the provider
1475+
macro_path = os.path.join(self.temp_dir, 'my_macros.yaml')
1476+
with open(macro_path, 'w') as f:
1477+
f.write(
1478+
"""
1479+
{%- macro power_expr(var, n) -%}
1480+
{{ var }} ** {{ n }}
1481+
{%- endmacro -%}
1482+
""")
1483+
1484+
# Create a provider that imports and uses the macro
1485+
templated_provider_path = os.path.join(
1486+
self.temp_dir, 'templated_provider.yaml')
1487+
with open(templated_provider_path, 'w') as f:
1488+
f.write(
1489+
"""
1490+
- type: yaml
1491+
transforms:
1492+
CustomPower:
1493+
config_schema:
1494+
properties:
1495+
n: {type: integer}
1496+
body: |
1497+
type: MapToFields
1498+
config:
1499+
language: python
1500+
append: true
1501+
fields:
1502+
power: "{% import 'my_macros.yaml' as m %}{{ m.power_expr('element', n) }}"
1503+
""")
1504+
1505+
loaded_providers = yaml_provider.load_providers(templated_provider_path)
1506+
test_providers = yaml_provider.InlineProvider(TEST_PROVIDERS)
1507+
merged_providers = yaml_provider.merge_providers(
1508+
loaded_providers, [test_providers])
1509+
1510+
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
1511+
pickle_library='cloudpickle')) as p:
1512+
results = p | YamlTransform(
1513+
'''
1514+
type: composite
1515+
transforms:
1516+
- type: Create
1517+
config:
1518+
elements: [2, 3]
1519+
- type: CustomPower
1520+
input: Create
1521+
config:
1522+
n: 3
1523+
output: CustomPower
1524+
''',
1525+
providers=merged_providers)
1526+
1527+
assert_that(
1528+
results,
1529+
equal_to(
1530+
[beam.Row(element=2, power=8), beam.Row(element=3, power=27)]))
1531+
14701532

14711533
@beam.transforms.ptransform.annotate_yaml
14721534
class LinearTransform(beam.PTransform):
@@ -1481,6 +1543,62 @@ def expand(self, pcoll):
14811543
return pcoll | beam.Map(lambda x: a * x.element + b)
14821544

14831545

1546+
class TestYamlExpandJinja(unittest.TestCase):
1547+
def setUp(self):
1548+
self.temp_dir = tempfile.mkdtemp()
1549+
# Create a macro file with leading comments (license header)
1550+
self.macro_path = os.path.join(self.temp_dir, 'my_macros.yaml')
1551+
with open(self.macro_path, 'w') as f:
1552+
f.write(
1553+
"""# coding=utf-8
1554+
# Licensed to the Apache Software Foundation...
1555+
# Some leading comment line
1556+
1557+
{%- macro add_n(val, n) -%}
1558+
{{ val }} + {{ n }}
1559+
{%- endmacro -%}
1560+
""")
1561+
1562+
# Create a pipeline template that includes/imports the macro
1563+
self.pipeline_path = os.path.join(self.temp_dir, 'my_pipeline.yaml')
1564+
with open(self.pipeline_path, 'w') as f:
1565+
f.write(
1566+
"""# coding=utf-8
1567+
# Licensed to the Apache Software Foundation...
1568+
1569+
{% import 'my_macros.yaml' as macros %}
1570+
type: composite
1571+
transforms:
1572+
- type: Create
1573+
config:
1574+
elements: [1, 2, 3]
1575+
- type: MapToFields
1576+
config:
1577+
language: python
1578+
fields:
1579+
result: {{ macros.add_n('element', 10) }}
1580+
""")
1581+
1582+
def tearDown(self):
1583+
shutil.rmtree(self.temp_dir)
1584+
1585+
def test_expand_jinja_with_leading_comments_and_imports(self):
1586+
# Read the pipeline template
1587+
with open(self.pipeline_path, 'r') as f:
1588+
template_content = f.read()
1589+
1590+
# Expand the jinja using our temp_dir as a search path
1591+
expanded = expand_jinja(template_content, {}, [self.temp_dir])
1592+
1593+
# Parse the expanded YAML
1594+
parsed = yaml.load(expanded, Loader=SafeLineLoader)
1595+
1596+
# Verify the comment-stripping and import resolution was successful
1597+
self.assertEqual(parsed['type'], 'composite')
1598+
self.assertEqual(
1599+
parsed['transforms'][1]['config']['fields']['result'], 'element + 10')
1600+
1601+
14841602
if __name__ == '__main__':
  # Raise the root logger to INFO before handing control to unittest.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()

0 commit comments

Comments
 (0)