Skip to content

Commit ad29db0

Browse files
committed
Expose attributes for metrics in `log_metric`. Add example notebook for trainer callback.
1 parent 0a1fb83 commit ad29db0

3 files changed

Lines changed: 135 additions & 3 deletions

File tree

dreadnode/main.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from dreadnode.types import (
5050
AnyDict,
51+
JsonDict,
5152
JsonValue,
5253
)
5354
from dreadnode.util import handle_internal_errors
@@ -759,6 +760,7 @@ def log_metric(
759760
origin: t.Any | None = None,
760761
timestamp: datetime | None = None,
761762
mode: MetricAggMode | None = None,
763+
attributes: JsonDict | None = None,
762764
to: ToObject = "task-or-run",
763765
) -> None:
764766
"""
@@ -788,6 +790,7 @@ def log_metric(
788790
- avg: the average of all reported values for this metric
789791
- sum: the cumulative sum of all reported values for this metric
790792
- count: increment every time this metric is logged - disregard value
793+
attributes: A dictionary of additional attributes to attach to the metric. Only applied when `value` is a raw number; ignored if `value` is already a `Metric`.
791794
to: The target object to log the metric to. Can be "task-or-run" or "run".
792795
Defaults to "task-or-run". If "task-or-run", the metric will be logged
793796
to the current task or run, whichever is the nearest ancestor.
@@ -842,6 +845,7 @@ def log_metric(
842845
origin: t.Any | None = None,
843846
timestamp: datetime | None = None,
844847
mode: MetricAggMode | None = None,
848+
attributes: JsonDict | None = None,
845849
to: ToObject = "task-or-run",
846850
) -> None:
847851
task = current_task_span.get()
@@ -854,7 +858,9 @@ def log_metric(
854858
metric = (
855859
value
856860
if isinstance(value, Metric)
857-
else Metric(float(value), step, timestamp or datetime.now(timezone.utc))
861+
else Metric(
862+
float(value), step, timestamp or datetime.now(timezone.utc), attributes or {}
863+
)
858864
)
859865
target.log_metric(key, metric, origin=origin, mode=mode)
860866

dreadnode/tracing/span.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ def log_metric(
527527
origin: t.Any | None = None,
528528
timestamp: datetime | None = None,
529529
mode: MetricAggMode | None = None,
530+
attributes: JsonDict | None = None,
530531
) -> None: ...
531532

532533
@t.overload
@@ -548,11 +549,14 @@ def log_metric(
548549
origin: t.Any | None = None,
549550
timestamp: datetime | None = None,
550551
mode: MetricAggMode | None = None,
552+
attributes: JsonDict | None = None,
551553
) -> None:
552554
metric = (
553555
value
554556
if isinstance(value, Metric)
555-
else Metric(float(value), step, timestamp or datetime.now(timezone.utc))
557+
else Metric(
558+
float(value), step, timestamp or datetime.now(timezone.utc), attributes or {}
559+
)
556560
)
557561

558562
if origin is not None:
@@ -740,6 +744,7 @@ def log_metric(
740744
origin: t.Any | None = None,
741745
timestamp: datetime | None = None,
742746
mode: MetricAggMode | None = None,
747+
attributes: JsonDict | None = None,
743748
) -> None: ...
744749

745750
@t.overload
@@ -761,11 +766,14 @@ def log_metric(
761766
origin: t.Any | None = None,
762767
timestamp: datetime | None = None,
763768
mode: MetricAggMode | None = None,
769+
attributes: JsonDict | None = None,
764770
) -> None:
765771
metric = (
766772
value
767773
if isinstance(value, Metric)
768-
else Metric(float(value), step, timestamp or datetime.now(timezone.utc))
774+
else Metric(
775+
float(value), step, timestamp or datetime.now(timezone.utc), attributes or {}
776+
)
769777
)
770778

771779
if origin is not None:

examples/model_training.ipynb

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## Model Training Callbacks\n",
8+
"\n",
9+
"You can use the library to log your model training progress to Strikes.\n"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import dreadnode as dn\n",
19+
"\n",
20+
"dn.configure(\n",
21+
" token=\"<YOUR API KEY>\", # Replace with your token\n",
22+
")"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": null,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"from datasets import load_dataset\n",
32+
"from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
33+
"\n",
34+
"# Load dataset\n",
35+
"dataset = load_dataset(\"glue\", \"sst2\")\n",
36+
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")\n",
37+
"\n",
38+
"# Take a small portion of the dataset\n",
39+
"dataset[\"train\"] = dataset[\"train\"].select(range(1000))\n",
40+
"dataset[\"validation\"] = dataset[\"validation\"].select(range(1000))\n",
41+
"\n",
42+
"# Preprocessing function\n",
43+
"def preprocess_function(examples):\n",
44+
" return tokenizer(examples[\"sentence\"], truncation=True, padding=\"max_length\")\n",
45+
"\n",
46+
"# Tokenize the dataset\n",
47+
"tokenized_datasets = dataset.map(preprocess_function, batched=True)\n",
48+
"\n",
49+
"# Load model\n",
50+
"model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-uncased\", num_labels=2)"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": null,
56+
"metadata": {},
57+
"outputs": [],
58+
"source": [
59+
"from transformers import Trainer, TrainingArguments\n",
60+
"\n",
61+
"from dreadnode.integrations.transformers import DreadnodeCallback\n",
62+
"import dreadnode as dn\n",
63+
"\n",
64+
"# Define training arguments\n",
65+
"training_args = TrainingArguments(\n",
66+
" output_dir=\"./results\",\n",
67+
" learning_rate=2e-5,\n",
68+
" per_device_train_batch_size=6,\n",
69+
" per_device_eval_batch_size=6,\n",
70+
" num_train_epochs=5,\n",
71+
" weight_decay=0.01,\n",
72+
" eval_strategy=\"steps\",\n",
73+
" eval_steps=5,\n",
74+
" load_best_model_at_end=False,\n",
75+
" push_to_hub=False,\n",
76+
" run_name=\"distilbert-sst2-demo\",\n",
77+
")\n",
78+
"\n",
79+
"# Initialize Trainer with DreadnodeCallback\n",
80+
"trainer = Trainer(\n",
81+
" model=model,\n",
82+
" args=training_args,\n",
83+
" train_dataset=tokenized_datasets[\"train\"],\n",
84+
" eval_dataset=tokenized_datasets[\"validation\"],\n",
85+
" tokenizer=tokenizer,\n",
86+
" callbacks=[DreadnodeCallback(project=\"training\")],\n",
87+
")\n",
88+
"\n",
89+
"# Train the model\n",
90+
"trainer.train()\n",
91+
"\n",
92+
"# Evaluate the model\n",
93+
"trainer.evaluate()"
94+
]
95+
}
96+
],
97+
"metadata": {
98+
"kernelspec": {
99+
"display_name": ".venv",
100+
"language": "python",
101+
"name": "python3"
102+
},
103+
"language_info": {
104+
"codemirror_mode": {
105+
"name": "ipython",
106+
"version": 3
107+
},
108+
"file_extension": ".py",
109+
"mimetype": "text/x-python",
110+
"name": "python",
111+
"nbconvert_exporter": "python",
112+
"pygments_lexer": "ipython3",
113+
"version": "3.10.14"
114+
}
115+
},
116+
"nbformat": 4,
117+
"nbformat_minor": 2
118+
}

0 commit comments

Comments
 (0)