@@ -250,11 +250,15 @@ def __init__(
250250 tracer : Tracer ,
251251 file_system : AbstractFileSystem ,
252252 prefix_path : str ,
253+ * ,
253254 params : AnyDict | None = None ,
254255 metrics : MetricDict | None = None ,
255256 run_id : str | None = None ,
256257 tags : t .Sequence [str ] | None = None ,
258+ autolog : bool = True ,
257259 ) -> None :
260+ self .autolog = autolog
261+
258262 self ._params = params or {}
259263 self ._metrics = metrics or {}
260264 self ._objects : dict [str , Object ] = {}
@@ -486,9 +490,8 @@ def log_input(
486490 value ,
487491 label = label ,
488492 event_name = EVENT_NAME_OBJECT_INPUT ,
489- ** attributes ,
490493 )
491- self ._inputs .append (ObjectRef (name , label = label , hash = hash_ ))
494+ self ._inputs .append (ObjectRef (name , label = label , hash = hash_ , attributes = attributes ))
492495
493496 def log_artifact (
494497 self ,
@@ -528,6 +531,7 @@ def log_metric(
528531 origin : t .Any | None = None ,
529532 timestamp : datetime | None = None ,
530533 mode : MetricAggMode | None = None ,
534+ prefix : str | None = None ,
531535 attributes : JsonDict | None = None ,
532536 ) -> None : ...
533537
@@ -539,6 +543,7 @@ def log_metric(
539543 * ,
540544 origin : t .Any | None = None ,
541545 mode : MetricAggMode | None = None ,
546+ prefix : str | None = None ,
542547 ) -> None : ...
543548
544549 def log_metric (
@@ -550,6 +555,7 @@ def log_metric(
550555 origin : t .Any | None = None ,
551556 timestamp : datetime | None = None ,
552557 mode : MetricAggMode | None = None ,
558+ prefix : str | None = None ,
553559 attributes : JsonDict | None = None ,
554560 ) -> None :
555561 metric = (
@@ -560,6 +566,10 @@ def log_metric(
560566 )
561567 )
562568
569+ key = re .sub (r"[^\w/]+" , "_" , key .lower ())
570+ if prefix is not None :
571+ key = f"{ prefix } .{ key } "
572+
563573 if origin is not None :
564574 origin_hash = self .log_object (
565575 origin ,
@@ -590,9 +600,8 @@ def log_output(
590600 value ,
591601 label = label ,
592602 event_name = EVENT_NAME_OBJECT_OUTPUT ,
593- ** attributes ,
594603 )
595- self ._outputs .append (ObjectRef (name , label = label , hash = hash_ ))
604+ self ._outputs .append (ObjectRef (name , label = label , hash = hash_ , attributes = attributes ))
596605
597606
598607class TaskSpan (Span , t .Generic [R ]):
@@ -694,9 +703,8 @@ def log_output(
694703 value ,
695704 label = label ,
696705 event_name = EVENT_NAME_OBJECT_OUTPUT ,
697- ** attributes ,
698706 )
699- self ._outputs .append (ObjectRef (name , label = label , hash = hash_ ))
707+ self ._outputs .append (ObjectRef (name , label = label , hash = hash_ , attributes = attributes ))
700708 return hash_
701709
702710 @property
@@ -726,9 +734,8 @@ def log_input(
726734 value ,
727735 label = label ,
728736 event_name = EVENT_NAME_OBJECT_INPUT ,
729- ** attributes ,
730737 )
731- self ._inputs .append (ObjectRef (name , label = label , hash = hash_ ))
738+ self ._inputs .append (ObjectRef (name , label = label , hash = hash_ , attributes = attributes ))
732739 return hash_
733740
734741 @property
@@ -777,6 +784,8 @@ def log_metric(
777784 )
778785 )
779786
787+ key = re .sub (r"[^\w/]+" , "_" , key .lower ())
788+
780789 if origin is not None :
781790 origin_hash = self .run .log_object (
782791 origin ,
@@ -795,7 +804,7 @@ def log_metric(
795804 #
796805 # Don't include `source` and `mode` as we handled it here.
797806 if (run := current_run_span .get ()) is not None :
798- run .log_metric (f" { self . _label } . { key } " , metric )
807+ run .log_metric (key , metric , prefix = self . _label )
799808
800809 def get_average_metric_value (self , key : str | None = None ) -> float :
801810 metrics = (
0 commit comments