Skip to content

Commit 6f92e9d

Browse files
adolkhanadolkhan
andauthored
Deep Memory recall fix (#2698)
fixing deep memory recall --------- Co-authored-by: adolkhan <adilkhan.sarsen@alumni.nu.edu.kz>
1 parent 8eb512c commit 6f92e9d

File tree

3 files changed

+54
-8
lines changed

3 files changed

+54
-8
lines changed

deeplake/client/test_client.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,19 @@ class Status:
128128
"--------------------------------------------------------------\n\n\n"
129129
)
130130

131+
completed_no_improvement = (
132+
"--------------------------------------------------------------\n"
133+
"| 1338464cd80cab681bfcfw23 |\n"
134+
"--------------------------------------------------------------\n"
135+
"| status | completed |\n"
136+
"--------------------------------------------------------------\n"
137+
"| progress | eta: 100.3 seconds |\n"
138+
"| | recall@10: 100.0% (+0.0%) |\n"
139+
"--------------------------------------------------------------\n"
140+
"| results | recall@10: 100.0% (+0.0%) |\n"
141+
"--------------------------------------------------------------\n\n\n"
142+
)
143+
131144
failed = (
132145
"--------------------------------------------------------------\n"
133146
"| 1338464cd80cab681bfcfff3 |\n"
@@ -168,15 +181,15 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
168181
progress=None,
169182
)
170183
response_schema = JobResponseStatusSchema(response=pending_response)
171-
response_schema.print_status(job_id, recall=None, importvement=None)
184+
response_schema.print_status(job_id, recall=None, improvement=None)
172185
captured = capsys.readouterr()
173186
assert captured.out == Status.pending
174187

175188
# for training that is in progress
176189
job_id = "3218464cd80cab681bfcfff3"
177190
training_response = create_response(job_id=job_id)
178191
response_schema = JobResponseStatusSchema(response=training_response)
179-
response_schema.print_status(job_id, recall="85.5", importvement="2.6")
192+
response_schema.print_status(job_id, recall="85.5", improvement="2.6")
180193
captured = capsys.readouterr()
181194
assert captured.out == Status.training
182195

@@ -187,10 +200,36 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
187200
status="completed",
188201
)
189202
response_schema = JobResponseStatusSchema(response=completed_response)
190-
response_schema.print_status(job_id, recall="85.5", importvement="2.6")
203+
response_schema.print_status(job_id, recall="85.5", improvement="2.6")
191204
captured = capsys.readouterr()
192205
assert captured.out == Status.completed
193206

207+
job_id = "1338464cd80cab681bfcfw23"
208+
completed_no_improvement_response = create_response(
209+
job_id=job_id,
210+
status="completed",
211+
progress={
212+
"eta": 100.34,
213+
"last_update_at": "2021-08-31T15:00:00.000000",
214+
"error": None,
215+
"train_recall@10": "87.8%",
216+
"best_recall@10": "100.0% (+0.0)%",
217+
"epoch": 0,
218+
"base_val_recall@10": 0.8292181491851807,
219+
"val_recall@10": "85.5%",
220+
"dataset": "query",
221+
"split": 0,
222+
"loss": -0.05437087118625641,
223+
"delta": 2.572011947631836,
224+
},
225+
)
226+
response_schema = JobResponseStatusSchema(
227+
response=completed_no_improvement_response
228+
)
229+
response_schema.print_status(job_id, recall="0.0", improvement="0.0")
230+
captured = capsys.readouterr()
231+
assert captured.out == Status.completed_no_improvement
232+
194233
# for jobs that failed
195234
job_id = "1338464cd80cab681bfcfff3"
196235
failed_response = create_response(
@@ -204,7 +243,7 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
204243
},
205244
)
206245
response_schema = JobResponseStatusSchema(response=failed_response)
207-
response_schema.print_status(job_id, recall=None, importvement=None)
246+
response_schema.print_status(job_id, recall=None, improvement=None)
208247
captured = capsys.readouterr()
209248
assert captured.out == Status.failed
210249

@@ -213,18 +252,21 @@ def test_deepmemory_print_status_and_list_jobs(capsys, precomputed_jobs_list):
213252
training_response,
214253
completed_response,
215254
failed_response,
255+
completed_no_improvement_response,
216256
]
217257
recalls = {
218258
"1238464cd80cab681bfcfff3": None,
219259
"3218464cd80cab681bfcfff3": "85.5",
220260
"2138464cd80cab681bfcfff3": "85.5",
221261
"1338464cd80cab681bfcfff3": None,
262+
"1338464cd80cab681bfcfw23": "0.0",
222263
}
223264
improvements = {
224265
"1238464cd80cab681bfcfff3": None,
225266
"3218464cd80cab681bfcfff3": "2.6",
226267
"2138464cd80cab681bfcfff3": "2.6",
227268
"1338464cd80cab681bfcfff3": None,
269+
"1338464cd80cab681bfcfw23": "0.0",
228270
}
229271
response_schema = JobResponseStatusSchema(response=responses)
230272
output_str = response_schema.print_jobs(

deeplake/client/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def print_status(
145145
self,
146146
job_id: Union[str, List[str]],
147147
recall: str,
148-
importvement: str,
148+
improvement: str,
149149
):
150150
if not isinstance(job_id, List):
151151
job_id = [job_id]
@@ -161,7 +161,7 @@ def print_status(
161161
indent=" " * 30,
162162
add_vertical_bars=True,
163163
recall=recall,
164-
improvement=importvement,
164+
improvement=improvement,
165165
)
166166

167167
print(line)
@@ -174,7 +174,7 @@ def print_status(
174174
" " * 30,
175175
add_vertical_bars=True,
176176
recall=recall,
177-
improvement=importvement,
177+
improvement=improvement,
178178
)
179179
progress_string = "| {:<27}| {:<30}"
180180
if progress == "None":
@@ -298,6 +298,8 @@ def get_best_recall_improvement(recall, improvement, best_recall):
298298
elif float(improvement) < float(bimprovement):
299299
return brecall, bimprovement
300300
else:
301+
if brecall > recall:
302+
return brecall, bimprovement
301303
return recall, improvement
302304

303305

deeplake/tests/dummy_data/deep_memory/precomputed_jobs_list.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ ID STATUS RESULTS PROGRESS
55
2138464cd80cab681bfcfff3 completed recall@10: 85.5% (+2.6%) eta: 100.3 seconds
66
recall@10: 85.5% (+2.6%)
77
1338464cd80cab681bfcfff3 failed not available yet eta: None seconds
8-
error: list indices must beintegers or slices,not str
8+
error: list indices must beintegers or slices,not str
9+
1338464cd80cab681bfcfw23 completed recall@10: 100.0% (+0.0%) eta: 100.3 seconds
10+
recall@10: 100.0% (+0.0%)

0 commit comments

Comments
 (0)