Skip to content

Commit cf63d92

Browse files
authored
feat(m-decomp): upgraded pipeline and added README, examples, and fixed module issues (#676)
* upd: clean over long or result files * add: validation code generation * add: validation code icl * add: prompt init * add: validation decision * add: decomp jinja * refact: pipeline and primary stages * refact: module logging * fea: validation report * fea: validation report icl * fea: cli script with config * add: examples * add: examples * add: README doc * pre-commit: add test attribute * upd: type annotations * fix: add constraint type annotation * upd: pre-commit format * upd: pre-commit type annotations * upd: pre-commit format * add: m_decompose tests * add: constraint retry * upd: constraint * upd: constraint * upd: same logmode * fix: a missed parse * pre-commit format * add: multi request support * clean * fix: type clean * upd: input file arg * fea: constraint retry * test * fix * fix * fix * clean: final result * add: decompose tests * add: decompose tests * clean: pre-commit format * clean: pre-commit formating on decompose tests
1 parent 7501093 commit cf63d92

55 files changed

Lines changed: 1811 additions & 3525 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cli/decompose/decompose.py

Lines changed: 141 additions & 142 deletions
Large diffs are not rendered by default.

cli/decompose/logging.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import logging
2+
import sys
3+
from enum import StrEnum
4+
5+
6+
class LogMode(StrEnum):
7+
demo = "demo"
8+
debug = "debug"
9+
10+
11+
_CONFIGURED = False
12+
13+
14+
def configure_logging(log_mode: LogMode = LogMode.demo) -> None:
15+
global _CONFIGURED
16+
17+
level = logging.DEBUG if log_mode == LogMode.debug else logging.INFO
18+
19+
root_logger = logging.getLogger()
20+
21+
if not _CONFIGURED:
22+
handler = logging.StreamHandler(sys.stdout)
23+
handler.setFormatter(
24+
logging.Formatter("[%(levelname)s] %(name)s | %(message)s")
25+
)
26+
root_logger.handlers.clear()
27+
root_logger.addHandler(handler)
28+
_CONFIGURED = True
29+
30+
root_logger.setLevel(level)
31+
32+
logging.getLogger("m_decompose").setLevel(level)
33+
logging.getLogger("mellea").setLevel(level)
34+
35+
36+
def get_logger(name: str) -> logging.Logger:
37+
return logging.getLogger(name)
38+
39+
40+
def log_section(logger: logging.Logger, title: str) -> None:
41+
logger.info("")
42+
logger.info("=" * 72)
43+
logger.info(title)
44+
logger.info("=" * 72)

cli/decompose/m_decomp_result_v1.py.jinja2

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@ import os
44
import textwrap
55

66
import mellea
7+
{%- if "code" in identified_constraints | map(attribute="val_strategy") %}
8+
from mellea.stdlib.requirement import req
9+
{% for c in identified_constraints %}
10+
{%- if c.val_fn %}
11+
from validations.{{ c.val_fn_name }} import validate_input as {{ c.val_fn_name }}
12+
{%- endif %}
13+
{%- endfor %}
14+
{%- endif %}
715

816
m = mellea.start_session()
917
{%- if user_inputs %}
@@ -30,7 +38,14 @@ except KeyError as e:
3038
{%- if item.constraints %}
3139
requirements=[
3240
{%- for c in item.constraints %}
41+
{%- if c.val_fn %}
42+
req(
43+
{{ c.constraint | tojson}},
44+
validation_fn={{ c.val_fn_name }},
45+
),
46+
{%- else %}
3347
{{ c.constraint | tojson}},
48+
{%- endif %}
3449
{%- endfor %}
3550
],
3651
{%- else %}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
{% if user_inputs -%}
2+
import os
3+
{% endif -%}
4+
import textwrap
5+
6+
import mellea
7+
{%- if "code" in identified_constraints | map(attribute="val_strategy") %}
8+
from mellea.stdlib.requirement import req
9+
{% for c in identified_constraints %}
10+
{%- if c.val_fn %}
11+
from validations.{{ c.val_fn_name }} import validate_input as {{ c.val_fn_name }}
12+
{%- endif %}
13+
{%- endfor %}
14+
{%- endif %}
15+
16+
m = mellea.start_session()
17+
{%- if user_inputs %}
18+
19+
20+
# User Input Variables
21+
try:
22+
{%- for var in user_inputs %}
23+
{{ var | lower }} = os.environ["{{ var | upper }}"]
24+
{%- endfor %}
25+
except KeyError as e:
26+
print(f"ERROR: One or more required environment variables are not set; {e}")
27+
exit(1)
28+
{%- endif %}
29+
{%- for item in subtasks %}
30+
31+
32+
{{ item.tag | lower }}_gnrl = textwrap.dedent(
33+
R"""
34+
{{ item.general_instructions | trim | indent(width=4, first=False) }}
35+
""".strip()
36+
)
37+
{{ item.tag | lower }} = m.instruct(
38+
{%- if not item.input_vars_required %}
39+
{{ item.subtask[3:] | trim | tojson }},
40+
{%- else %}
41+
textwrap.dedent(
42+
R"""
43+
{{ item.subtask[3:] | trim }}
44+
45+
Here are the input variables and their content:
46+
{%- for var in item.input_vars_required %}
47+
48+
- {{ var | upper }} = {{ "{{" }}{{ var | upper }}{{ "}}" }}
49+
{%- endfor %}
50+
""".strip()
51+
),
52+
{%- endif %}
53+
{%- if item.constraints %}
54+
requirements=[
55+
{%- for c in item.constraints %}
56+
{%- if c.val_fn %}
57+
req(
58+
{{ c.constraint | tojson}},
59+
validation_fn={{ c.val_fn_name }},
60+
),
61+
{%- else %}
62+
{{ c.constraint | tojson}},
63+
{%- endif %}
64+
{%- endfor %}
65+
],
66+
{%- else %}
67+
requirements=None,
68+
{%- endif %}
69+
{%- if item.input_vars_required %}
70+
user_variables={
71+
{%- for var in item.input_vars_required %}
72+
{{ var | upper | tojson }}: {{ var | lower }},
73+
{%- endfor %}
74+
},
75+
{%- endif %}
76+
grounding_context={
77+
"GENERAL_INSTRUCTIONS": {{ item.tag | lower }}_gnrl,
78+
{%- for var in item.depends_on %}
79+
{{ var | upper | tojson }}: {{ var | lower }}.value,
80+
{%- endfor %}
81+
},
82+
)
83+
assert {{ item.tag | lower }}.value is not None, 'ERROR: task "{{ item.tag | lower }}" execution failed'
84+
{%- if loop.last %}
85+
86+
87+
final_answer = {{ item.tag | lower }}.value
88+
89+
print(final_answer)
90+
{%- endif -%}
91+
{%- endfor -%}

0 commit comments

Comments
 (0)