Skip to content

Commit d32798d

Browse files
Rework import scripts
1 parent 1e8ccbb commit d32798d

2 files changed

Lines changed: 69 additions & 67 deletions

File tree

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import json
12
import xml.etree.ElementTree as ET
23

34
from django.core.management.base import BaseCommand
45
from django.db import transaction
6+
from tqdm import tqdm
57

6-
from problem.models import FracasPremise, FracasProblem
7-
from problem.utils import progress
8+
from langpro_annotator.logger import logger
9+
from problem.services import get_fracas_problems
10+
from problem.models import Problem
811

912

1013
class Command(BaseCommand):
@@ -22,7 +25,18 @@ def handle(self, *args, **options):
2225
fracas_path = options["fracas_path"]
2326
self.import_fracas_problems(fracas_path)
2427

25-
def annotate_section_subsections(self, tree: ET.ElementTree) -> None:
28+
@staticmethod
29+
def _text_from_element(element: ET.Element) -> str:
30+
"""
31+
Extracts stripped text from an XML element, returning an empty string if the element is None or has no text.
32+
"""
33+
return element.text.strip() if element is not None and element.text else ""
34+
35+
@staticmethod
36+
def _annotate_section_subsections(tree: ET.ElementTree) -> None:
37+
"""
38+
Annotates each problem in the XML tree with its corresponding section, subsection, and subsubsection.
39+
"""
2640
current_section = None
2741
current_subsection = None
2842
current_subsubsection = None
@@ -50,76 +64,63 @@ def annotate_section_subsections(self, tree: ET.ElementTree) -> None:
5064
element.set("subsubsection", current_subsubsection)
5165

5266
def import_fracas_problems(self, fracas_path: str) -> None:
53-
# Parse the XML file
5467
tree = ET.parse(fracas_path)
55-
self.annotate_section_subsections(tree)
68+
self._annotate_section_subsections(tree)
5669
root = tree.getroot()
57-
5870
all_problems = root.findall("problem")
59-
total = len(all_problems)
60-
n = 1
6171

72+
created = 0
6273
skipped = 0
6374

64-
def text_from_element(element: ET.Element) -> str:
65-
"""
66-
Extracts stripped text from an XML element, returning an empty string if the element is None or has no text.
67-
"""
68-
return element.text.strip() if element is not None and element.text else ""
75+
existing_fracas_problems = get_fracas_problems()
76+
existing_fracas_ids = {p.fracas_id for p in existing_fracas_problems}
6977

70-
for problem in root.findall("problem"):
78+
for problem in tqdm(all_problems, desc="Importing FraCaS problems"):
7179
problem_id = problem.get("id")
72-
7380
if problem_id is None:
7481
raise ValueError(
7582
"Problem ID is missing in the XML file for problem: {}".format(
7683
problem
7784
)
7885
)
7986

80-
progress(n, total)
81-
n += 1
82-
83-
if FracasProblem.objects.filter(fracas_id=problem_id).exists():
87+
if int(problem_id) in existing_fracas_ids:
8488
skipped += 1
8589
continue
8690

87-
question = text_from_element(problem.find("q"))
88-
hypothesis = text_from_element(problem.find("h"))
89-
answer = text_from_element(problem.find("a"))
90-
note = text_from_element(problem.find("note"))
91+
question = self._text_from_element(problem.find("q"))
92+
hypothesis = self._text_from_element(problem.find("h"))
93+
answer = self._text_from_element(problem.find("a"))
94+
note = self._text_from_element(problem.find("note"))
9195

9296
section = problem.get("section")
9397
subsection = problem.get("subsection")
9498
fracas_answer = problem.get("fracas_answer")
9599
fracas_nonstandard = problem.get("fracas_nonstandard", False) == "true"
96100

101+
premise_nodes = problem.findall("p")
102+
premises = [node.text.strip() for node in premise_nodes if node.text]
103+
97104
with transaction.atomic():
98-
fracas_problem = FracasProblem.objects.create(
99-
fracas_id=int(problem_id),
100-
question=question,
101-
hypothesis=hypothesis,
102-
answer=answer,
103-
fracas_answer=fracas_answer,
104-
fracas_non_standard=fracas_nonstandard,
105-
note=note,
106-
section_name=section,
107-
subsection_name=subsection,
105+
Problem.objects.create(
106+
type=Problem.ProblemType.FRACAS,
107+
content=json.dumps(
108+
{
109+
"fracas_id": int(problem_id),
110+
"question": question,
111+
"hypothesis": hypothesis,
112+
"answer": answer,
113+
"fracas_answer": fracas_answer,
114+
"fracas_non_standard": fracas_nonstandard,
115+
"note": note,
116+
"section_name": section,
117+
"subsection_name": subsection,
118+
"premises": premises,
119+
}
120+
),
108121
)
122+
created += 1
109123

110-
premises = problem.findall("p")
111-
for premise in premises:
112-
premise_index = premise.get("idx", None)
113-
if premise_index is None:
114-
raise ValueError(
115-
"Premise index is missing in the XML file for problem: {}".format(
116-
problem
117-
)
118-
)
119-
FracasPremise.objects.create(
120-
fracas_problem=fracas_problem,
121-
premise_index=int(premise_index),
122-
premise=premise.text.strip() if premise.text else "",
123-
)
124-
125-
print(f"FraCaS problems import complete! Total: {total} | Skipped: {skipped}")
124+
logger.info(
125+
f"FraCaS problems import complete! Total: {created} | Skipped: {skipped}"
126+
)
Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import csv
2+
import json
23

34
from django.core.management.base import BaseCommand
4-
from problem.utils import progress
5-
from problem.models import SickProblem
5+
from tqdm import tqdm
6+
7+
from langpro_annotator.logger import logger
8+
from problem.models import Problem
9+
from problem.services import get_sick_problems
610

711

812
class Command(BaseCommand):
@@ -25,30 +29,27 @@ def import_sick_problems(self, sick_path: str) -> None:
2529
Import SICK problems from SICK.txt (a TSV file) and enter them into the database.
2630
"""
2731

28-
print("Importing SICK problems...")
29-
3032
skipped = 0
33+
created = 0
34+
35+
existing_sick_problems = get_sick_problems()
36+
existing_pair_ids = {p.pair_id for p in existing_sick_problems}
3137

3238
with open(sick_path, "r", encoding="utf-8") as file:
3339
reader = csv.DictReader(file, delimiter="\t")
3440
problem_list = list(reader)
3541

36-
total = len(problem_list)
37-
n = 1
38-
39-
for row in problem_list:
40-
progress(n, total)
41-
n += 1
42-
if SickProblem.objects.filter(pair_id=row["pair_ID"]).exists():
42+
for problem in tqdm(problem_list, desc="Importing SICK problems"):
43+
if problem["pair_ID"] in existing_pair_ids:
4344
skipped += 1
4445
continue
4546

46-
SickProblem.objects.create(
47-
pair_id=row["pair_ID"],
48-
sentence_one=row["sentence_A"],
49-
sentence_two=row["sentence_B"],
50-
entailment_label=row["entailment_label"],
51-
relatedness_score=row["relatedness_score"],
47+
created += 1
48+
Problem.objects.create(
49+
type=Problem.ProblemType.SICK,
50+
content=json.dumps(problem),
5251
)
5352

54-
print(f"SICK problems import complete! Total: {total} | Skipped: {skipped}")
53+
logger.info(
54+
f"SICK problems import complete! Created: {created} | Skipped: {skipped}"
55+
)

0 commit comments

Comments
 (0)