Skip to content

Commit 2a81596

Browse files
authored
Add unittests for output_methods/file.py (#978)
* Add tests for file output. Also: * Add docstrings for file.py * Drop empty columsn before doing dataframe concatenations to prevent panda warnings. * Fix the "update" logic in FileOutput.out() when a single matching row exists. * Remove a defunct comment and debug print. * Remove unused imports and variables
1 parent 66cc372 commit 2a81596

2 files changed

Lines changed: 259 additions & 7 deletions

File tree

codecarbon/output_methods/file.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,29 @@
1313
class FileOutput(BaseOutput):
1414
"""
1515
Saves experiment artifacts to a file
16+
17+
Attributes:
18+
output_file_name: str, name of file to write to.
19+
output_dir: str, path to directory to write to.
20+
save_file_path: str, path to file to write to.
21+
on_csv_write: str, "append" or "update", whether or not to append or overwrite a file if it exists.
1622
"""
1723

1824
def __init__(
1925
self, output_file_name: str, output_dir: str, on_csv_write: str = "append"
2026
):
27+
"""
28+
Initialize the FileOutput object.
29+
30+
Args:
31+
output_file_name: name of file to write to.
32+
output_dir: path to directory to write to.
33+
on_csv_write: "append" or "update", whether or not to append or overwrite a file if it exists
34+
35+
Raises:
36+
ValueError: If the on_csv_write value is invalid.
37+
OSError: If the output directory does not exist.
38+
"""
2139
if on_csv_write not in {"append", "update"}:
2240
raise ValueError(
2341
f"Unknown `on_csv_write` value: {on_csv_write}"
@@ -33,7 +51,16 @@ def __init__(
3351
f"Emissions data (if any) will be saved to file {os.path.abspath(self.save_file_path)}"
3452
)
3553

36-
def has_valid_headers(self, data: EmissionsData):
54+
def has_valid_headers(self, data: EmissionsData) -> bool:
55+
"""
56+
Checks self.save_file_path has headers matching those from passed data.
57+
58+
Args:
59+
data: EmissionsData object with valid headers.
60+
61+
Returns:
62+
True if the file has valid headers, False otherwise.
63+
"""
3764
with open(self.save_file_path) as csv_file:
3865
csv_reader = csv.DictReader(csv_file)
3966
csv_entries_list = list(csv_reader)
@@ -44,11 +71,21 @@ def has_valid_headers(self, data: EmissionsData):
4471
list_of_column_names = list(dict_from_csv.keys())
4572
return list(data.values.keys()) == list_of_column_names
4673

47-
def out(self, total: EmissionsData, delta: EmissionsData):
74+
def out(self, total: EmissionsData, _: EmissionsData):
4875
"""
49-
Save the emissions data to a CSV file.
50-
If the file already exists, append the new data to it.
51-
param `delta` is not used in this method.
76+
Save the emissions data from a whole run to a CSV file.
77+
78+
* If the file does not exist, then create it.
79+
* If the file already exists but has invalid headers, then back it up and replace with new data.
80+
* If the file already exists and has valid headers:
81+
* If it has no rows with a matching run ID, append the new data.
82+
* If it has one row with a matching run ID, then replace that row with the new data.
83+
* If it has > one row with a matching run ID, append the new data
84+
85+
Args:
86+
total: data to save.
87+
88+
5289
"""
5390
file_exists: bool = os.path.isfile(self.save_file_path)
5491
if file_exists and not self.has_valid_headers(total):
@@ -60,6 +97,10 @@ def out(self, total: EmissionsData, delta: EmissionsData):
6097
df = new_df
6198
elif self.on_csv_write == "append":
6299
df = pd.read_csv(self.save_file_path)
100+
# Filter out empty or all-NA columns, to avoid warnings from Pandas,
101+
# see https://github.com/pandas-dev/pandas/issues/55928
102+
df = df.dropna(axis=1, how="all")
103+
new_df = new_df.dropna(axis=1, how="all")
63104
df = pd.concat([df, new_df])
64105
else:
65106
df = pd.read_csv(self.save_file_path)
@@ -74,13 +115,22 @@ def out(self, total: EmissionsData, delta: EmissionsData):
74115
)
75116
df = pd.concat([df, new_df])
76117
else:
77-
df.at[df.run_id == total.run_id, total.values.keys()] = (
78-
total.values.values()
118+
update_values = {}
119+
for col, val in dict(total.values).items():
120+
# Explicitly cast new values to prevent warnings about incompatible dtypes.
121+
update_values[col] = df[col].dtype.type(val)
122+
df.loc[df.run_id == total.run_id, update_values.keys()] = (
123+
update_values.values()
79124
)
80125

81126
df.to_csv(self.save_file_path, index=False)
82127

83128
def task_out(self, data: List[TaskEmissionsData], experiment_name: str):
129+
"""
130+
Save the emissions data from a single task in an experiment run to a CSV file.
131+
132+
Does not attempt to backup existing files or prevent ovewritting them.
133+
"""
84134
run_id = data[0].run_id
85135
save_task_file_path = os.path.join(
86136
self.output_dir, "emissions_" + experiment_name + "_" + run_id + ".csv"
@@ -90,6 +140,8 @@ def task_out(self, data: List[TaskEmissionsData], experiment_name: str):
90140
[dict(data_point.values) for data_point in data]
91141
)
92142
# Filter out empty or all-NA columns, to avoid warnings from Pandas
143+
# see https://github.com/pandas-dev/pandas/issues/55928
144+
df = df.dropna(axis=1, how="all")
93145
new_df = new_df.dropna(axis=1, how="all")
94146
df = pd.concat([df, new_df], ignore_index=True)
95147
df.to_csv(save_task_file_path, index=False)

tests/output_methods/test_file.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import os
2+
import shutil
3+
import tempfile
4+
import unittest
5+
from unittest.mock import MagicMock, patch
6+
7+
import pandas as pd
8+
9+
from codecarbon.output_methods.emissions_data import EmissionsData, TaskEmissionsData
10+
from codecarbon.output_methods.file import FileOutput
11+
12+
13+
class TestFileOutput(unittest.TestCase):
14+
def setUp(self):
15+
self.temp_dir = tempfile.mkdtemp()
16+
self.emissions_data = EmissionsData(
17+
timestamp="2023-01-01T00:00:00",
18+
project_name="test_project",
19+
run_id="test_run_id",
20+
experiment_id="test_experiment_id",
21+
duration=10,
22+
emissions=0.5,
23+
emissions_rate=0.05,
24+
cpu_power=20,
25+
gpu_power=30,
26+
ram_power=5,
27+
cpu_energy=200,
28+
gpu_energy=300,
29+
ram_energy=50,
30+
energy_consumed=550,
31+
water_consumed=0.1,
32+
country_name="Testland",
33+
country_iso_code="TS",
34+
region="Test Region",
35+
cloud_provider="Test Cloud",
36+
cloud_region="test-cloud-1",
37+
os="TestOS",
38+
python_version="3.8",
39+
codecarbon_version="2.0",
40+
cpu_count=4,
41+
cpu_model="Test CPU",
42+
gpu_count=1,
43+
gpu_model="Test GPU",
44+
longitude=0,
45+
latitude=0,
46+
ram_total_size=16,
47+
tracking_mode="machine",
48+
on_cloud="true",
49+
pue=1.5,
50+
wue=0.5,
51+
)
52+
53+
def tearDown(self):
54+
shutil.rmtree(self.temp_dir)
55+
56+
def test_file_output_initialization(self):
57+
FileOutput("test.csv", self.temp_dir)
58+
59+
def test_file_output_initialization_invalid_csv_write_mode(self):
60+
with self.assertRaises(ValueError):
61+
FileOutput("test.csv", self.temp_dir, on_csv_write="invalid_option")
62+
63+
def test_file_output_initialization_invalid_dir(self):
64+
with self.assertRaises(OSError):
65+
FileOutput("test.csv", "/non/existent/dir")
66+
67+
def test_has_valid_headers_success(self):
68+
file_output = FileOutput("test.csv", self.temp_dir)
69+
file_output.out(self.emissions_data, MagicMock())
70+
71+
self.assertTrue(file_output.has_valid_headers(self.emissions_data))
72+
73+
def test_has_valid_headers_failure(self):
74+
file_output = FileOutput("test.csv", self.temp_dir)
75+
file_output.out(self.emissions_data, MagicMock())
76+
77+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
78+
df.rename(columns={"wue": "new_header"}, inplace=True)
79+
df.to_csv(os.path.join(self.temp_dir, "test.csv"), index=False)
80+
81+
self.assertFalse(file_output.has_valid_headers(self.emissions_data))
82+
83+
@patch("codecarbon.output_methods.file.FileOutput.has_valid_headers")
84+
def test_file_output_out_file_exists_invalid_headers(self, mock_has_valid_headers):
85+
file_output = FileOutput("test.csv", self.temp_dir, on_csv_write="append")
86+
file_output.out(self.emissions_data, MagicMock())
87+
88+
mock_has_valid_headers.return_value = False
89+
file_output.out(self.emissions_data, MagicMock())
90+
91+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv.bak"))
92+
self.assertEqual(len(df), 1)
93+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
94+
self.assertEqual(len(df), 1)
95+
96+
def test_file_output_out_update_no_file_exists(self):
97+
file_output = FileOutput("test.csv", self.temp_dir, on_csv_write="update")
98+
file_output.out(self.emissions_data, MagicMock())
99+
100+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
101+
self.assertEqual(len(df), 1)
102+
103+
def test_file_output_out_append_no_file_exists(self):
104+
file_output = FileOutput("test.csv", self.temp_dir, on_csv_write="append")
105+
file_output.out(self.emissions_data, MagicMock())
106+
107+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
108+
self.assertEqual(len(df), 1)
109+
110+
def test_file_output_out_append_file_exists(self):
111+
file_output = FileOutput("test.csv", self.temp_dir, on_csv_write="append")
112+
file_output.out(self.emissions_data, MagicMock())
113+
file_output.out(self.emissions_data, MagicMock())
114+
115+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
116+
self.assertEqual(len(df), 2)
117+
118+
def test_file_output_out_update_file_exists_no_matching_row(self):
119+
file_output = FileOutput("test.csv", self.temp_dir, on_csv_write="update")
120+
file_output.out(self.emissions_data, MagicMock())
121+
122+
updated_emissions_data = self.emissions_data
123+
updated_emissions_data.run_id = "new_test_run_id"
124+
file_output.out(updated_emissions_data, MagicMock())
125+
126+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
127+
self.assertEqual(len(df), 2)
128+
129+
def test_file_output_out_update_file_exists_multiple_matching_rows(self):
130+
file_output = FileOutput("test.csv", self.temp_dir, on_csv_write="update")
131+
file_output.out(self.emissions_data, MagicMock())
132+
133+
# Manually add a duplicate row to simulate the condition
134+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
135+
df = pd.concat([df, df])
136+
df.to_csv(os.path.join(self.temp_dir, "test.csv"), index=False)
137+
138+
file_output.out(self.emissions_data, MagicMock())
139+
140+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
141+
self.assertEqual(len(df), 3)
142+
143+
def test_file_output_out_update_file_exists_one_matchingrows(self):
144+
file_output = FileOutput("test.csv", self.temp_dir, on_csv_write="update")
145+
file_output.out(self.emissions_data, MagicMock())
146+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
147+
self.assertEqual(df["cpu_power"].iloc[0], 20)
148+
149+
new_data = self.emissions_data
150+
new_data.cpu_power = 2
151+
file_output.out(new_data, MagicMock())
152+
df = pd.read_csv(os.path.join(self.temp_dir, "test.csv"))
153+
self.assertEqual(df["cpu_power"].iloc[0], 2)
154+
155+
def test_file_output_task_out(self):
156+
task_emissions_data = [
157+
TaskEmissionsData(
158+
task_name="test_task",
159+
timestamp="2023-01-01T00:00:00",
160+
project_name="test_project",
161+
run_id="test_run_id",
162+
duration=10,
163+
emissions=0.5,
164+
emissions_rate=0.05,
165+
cpu_power=20,
166+
gpu_power=30,
167+
ram_power=5,
168+
cpu_energy=200,
169+
gpu_energy=300,
170+
ram_energy=50,
171+
energy_consumed=550,
172+
water_consumed=0.1,
173+
country_name="Testland",
174+
country_iso_code="TS",
175+
region="Test Region",
176+
cloud_provider="Test Cloud",
177+
cloud_region="test-cloud-1",
178+
os="TestOS",
179+
python_version="3.8",
180+
codecarbon_version="2.0",
181+
cpu_count=4,
182+
cpu_model="Test CPU",
183+
gpu_count=1,
184+
gpu_model="Test GPU",
185+
longitude=0,
186+
latitude=0,
187+
ram_total_size=16,
188+
tracking_mode="machine",
189+
on_cloud="true",
190+
)
191+
]
192+
file_output = FileOutput("test.csv", self.temp_dir)
193+
file_output.task_out(task_emissions_data, "test_experiment")
194+
195+
expected_file = os.path.join(
196+
self.temp_dir, "emissions_test_experiment_test_run_id.csv"
197+
)
198+
self.assertTrue(os.path.exists(expected_file))
199+
df = pd.read_csv(expected_file)
200+
self.assertEqual(len(df), 1)

0 commit comments

Comments
 (0)