@@ -36,18 +36,82 @@ def __init__(
3636 max_length : int = 2048 ,
3737 shuffle_seed : int = 42 ,
3838 chat_template_type : ChatTemplateType = ChatTemplateType .QWEN3 ,
39+ display : bool = False ,
3940 ):
4041 self .tokenizer = tokenizer
4142 self .max_length = max_length
4243 self .shuffle_seed = shuffle_seed
4344 self .chat_template_type = chat_template_type
45+ self .display = display
46+ self .display_count = 0 # Track how many samples have been displayed
4447
4548 # Get chat template
4649 template = template_manager .get_template_dict (chat_template_type )
4750 self .user_header = template ["user_header" ]
4851 self .assistant_header = template ["assistant_header" ]
4952 self .system_prompt = template ["system_prompt" ]
5053
54+ def _visualize_loss_mask (
55+ self , input_ids : torch .Tensor , loss_mask : torch .Tensor , conversation : str
56+ ) -> None :
57+ """
58+ Visualize loss_mask with color-coded output.
59+
60+ Args:
61+ input_ids: Token IDs
62+ loss_mask: Loss mask tensor (1 for training, 0 for ignoring)
63+ conversation: Original conversation text
64+ """
65+ # ANSI color codes
66+ RED = "\033 [91m" # For masked out tokens (loss_mask=0)
67+ GREEN = "\033 [92m" # For training tokens (loss_mask=1)
68+ RESET = "\033 [0m" # Reset color
69+ BOLD = "\033 [1m"
70+
71+ rank0_print ("\n " + "=" * 80 )
72+ rank0_print (f"{ BOLD } Loss Mask Visualization{ RESET } " )
73+ rank0_print ("=" * 80 )
74+
75+ # Display legend
76+ rank0_print (f"\n { BOLD } Legend:{ RESET } " )
77+ rank0_print (f"{ GREEN } ■ Green: Training tokens (loss_mask=1){ RESET } " )
78+ rank0_print (f"{ RED } ■ Red: Ignored tokens (loss_mask=0){ RESET } " )
79+
80+ # Display statistics
81+ total_tokens = len (loss_mask )
82+ training_tokens = loss_mask .sum ().item ()
83+ ignored_tokens = total_tokens - training_tokens
84+ training_ratio = training_tokens / total_tokens * 100 if total_tokens > 0 else 0
85+
86+ rank0_print (f"\n { BOLD } Statistics:{ RESET } " )
87+ rank0_print (f"Total tokens: { total_tokens } " )
88+ rank0_print (f"Training tokens: { training_tokens } ({ training_ratio :.2f} %)" )
89+ rank0_print (f"Ignored tokens: { ignored_tokens } ({ 100 - training_ratio :.2f} %)" )
90+
91+ # Display token-by-token visualization
92+ rank0_print (f"\n { BOLD } Token-by-token visualization:{ RESET } " )
93+ rank0_print ("-" * 80 )
94+
95+ decoded_tokens = []
96+ for token_id , mask_value in zip (input_ids , loss_mask ):
97+ token_text = self .tokenizer .decode ([token_id ], skip_special_tokens = False )
98+
99+ # Choose color based on mask value
100+ color = GREEN if mask_value == 1 else RED
101+
102+ # Format token with color
103+ colored_token = f"{ color } { token_text } { RESET } "
104+ decoded_tokens .append (colored_token )
105+
106+ # Print all tokens directly
107+ rank0_print ("" .join (decoded_tokens ))
108+
109+ # Display original conversation for reference
110+ rank0_print (f"\n { BOLD } Original conversation:{ RESET } " )
111+ rank0_print ("-" * 80 )
112+ rank0_print (conversation )
113+ rank0_print ("=" * 80 + "\n " )
114+
51115 def build_dataset (self , datapath : str , num_proc : int = 8 ) -> Dataset :
52116 try :
53117 # Load and shuffle dataset
@@ -67,8 +131,10 @@ def build_dataset(self, datapath: str, num_proc: int = 8) -> Dataset:
67131 desc = "Processing conversations" ,
68132 )
69133
70- # Filter out None results
71- processed_ds = processed_ds .filter (lambda x : x ["input_ids" ] is not None )
134+ # Filter out None results with multiprocessing support
135+ processed_ds = processed_ds .filter (
136+ lambda x : x ["input_ids" ] is not None , num_proc = num_proc
137+ )
72138 processed_ds .set_format (type = "torch" )
73139
74140 return processed_ds
@@ -134,6 +200,11 @@ def _process_single_conversation(
134200 input_ids = torch .tensor (input_ids )
135201 attention_mask = torch .ones_like (input_ids )
136202
203+ # Visualize loss mask if display mode is enabled
204+ if self .display and self .display_count == 0 :
205+ self ._visualize_loss_mask (input_ids , loss_mask , conversation )
206+ self .display_count += 1
207+
137208 return {
138209 "input_ids" : input_ids [None , :],
139210 "attention_mask" : attention_mask [None , :],
@@ -262,6 +333,7 @@ def __init__(
262333 tokenizer : AutoTokenizer ,
263334 model_max_length : int = 2048 ,
264335 chat_template_type : Optional [Union [str , ChatTemplateType ]] = None ,
336+ display : bool = False ,
265337 ):
266338 """
267339 Initialize DatasetManager with DataArguments.
@@ -274,10 +346,12 @@ def __init__(
274346 - ChatTemplateType enum value (e.g., ChatTemplateType.QWEN3)
275347 - String (e.g., "llama", "qwen")
276348 - None (will default to LLAMA)
349+ display: Whether to display loss mask visualization for the first sample
277350 """
278351 self .data_args = data_args
279352 self .tokenizer = tokenizer
280353 self .model_max_length = model_max_length
354+ self .display = display
281355
282356 # Convert chat_template_type to ChatTemplateType enum
283357 if chat_template_type is None :
@@ -293,6 +367,7 @@ def __init__(
293367 max_length = model_max_length ,
294368 shuffle_seed = data_args .shuffle_seed ,
295369 chat_template_type = chat_template_type ,
370+ display = display ,
296371 )
297372
298373 def create_datasets (self ) -> Tuple [Dataset , Optional [Dataset ]]:
@@ -305,8 +380,8 @@ def create_datasets(self) -> Tuple[Dataset, Optional[Dataset]]:
305380 """
306381 # Determine number of processes
307382 num_proc = self .data_args .num_proc
308- if self .data_args . preprocessing_num_workers is not None :
309- num_proc = self . data_args . preprocessing_num_workers
383+ if self .display :
384+ num_proc = None
310385
311386 # Create train dataset
312387 train_dataset = self .dataset_builder .build_dataset (
0 commit comments