22import random
33from typing import Optional
44
def prepend_label_to_image(
    x_input: torch.Tensor,
    label: int,
    num_classes: int,
) -> torch.Tensor:
    """
    Embed a class label into the leading entries of a flattened input.

    The first `num_classes` elements of the result form a label section:
    they are all 0 except at index `label`, which holds the maximum value
    of `x_input`. Every element from index `num_classes` onward is copied
    unchanged from `x_input`. Despite the name, nothing is prepended: the
    label section overwrites the first `num_classes` input values, so the
    result has the same size as the input.

    The same function produces positive samples (pass the true label) and
    negative samples (pass any other label).

    Args:
        x_input: The original flattened input vector (1D Tensor).
        label: The 0-indexed class label to embed.
        num_classes: The total number of classes (c).

    Returns:
        A new tensor with the label embedded, created with
        `torch.zeros_like(x_input)`; `x_input` itself is not modified.

    Raises:
        ValueError: If `x_input` is not a 1D tensor, `num_classes` is not
            positive, `label` is outside `[0, num_classes)`, or
            `num_classes` exceeds the length of `x_input`.
    """
    if not isinstance(x_input, torch.Tensor) or x_input.ndim != 1:
        raise ValueError("x_input must be a 1D PyTorch Tensor.")
    # Validate num_classes before label: when num_classes <= 0 no label can
    # satisfy the range check, so checking the label first would hide the
    # real cause behind an "out of bounds" message.
    if num_classes <= 0:
        raise ValueError("num_classes must be positive.")
    if not (0 <= label < num_classes):
        raise ValueError(
            f"Label {label} is out of bounds for {num_classes} classes."
        )

    d = x_input.shape[0]
    if num_classes > d:
        raise ValueError(
            f"num_classes ({num_classes}) cannot be larger than input size ({d})"
        )

    # From here on 1 <= num_classes <= d, so x_input is non-empty and
    # torch.max is always defined.
    x_output = torch.zeros_like(x_input)

    # Part 1: mark the label inside the (otherwise zero) label section.
    x_output[label] = torch.max(x_input)

    # Part 2: copy everything after the label section. When d == num_classes
    # both slices are empty and the assignment is a no-op.
    x_output[num_classes:] = x_input[num_classes:]

    return x_output
12754
12855
# --- Example Usage (for demonstration if you run this file directly) ---
if __name__ == "__main__":
    # Example: 10 features in original input, 4 classes
    original_x = torch.rand(10)  # Random data
    true_class_label = 1
    total_classes = 4

    print(f"Original x_input: {original_x}")
    print(f"Original shape: {original_x.shape}")
    print(f"True label: {true_class_label}")
    print(f"Num classes: {total_classes}")
    print("-" * 50)

    # Positive sample: embed the true label.
    x_positive = prepend_label_to_image(
        original_x, true_class_label, total_classes
    )
    print(f"x_pos (true label {true_class_label} embedded): {x_positive}")
    print(f"x_pos shape: {x_positive.shape}")  # Should be [10] - same as input
    print("-" * 50)

    # Negative sample: embed a randomly chosen label that differs from the
    # true one. Choosing the false label is the caller's job.
    available_labels = [i for i in range(total_classes) if i != true_class_label]
    rand_negative_label = random.choice(available_labels)

    x_negative_random = prepend_label_to_image(
        original_x, rand_negative_label, total_classes
    )
    print(f"x_neg (random false label {rand_negative_label} embedded): {x_negative_random}")
    print(f"x_neg shape: {x_negative_random.shape}")  # Should be [10] - same as input
    print("-" * 50)

    # Test with MNIST-like data: a 784-element vector, like a flattened
    # 28x28 image.
    mnist_like = torch.rand(784)
    num_classes_mnist = 5

    print(f"MNIST-like original: shape {mnist_like.shape}")
    x_mnist_pos = prepend_label_to_image(mnist_like, 2, num_classes_mnist)
    print(f"MNIST-like after embedding label 2: shape {x_mnist_pos.shape}")  # Should be [784]
    print(f"First 10 elements: {x_mnist_pos[:10]}")
    print(f"Label embedded at position 2: {x_mnist_pos[2]}")  # Should be max value
    print(f"Positions 0,1,3,4 should be 0: {x_mnist_pos[[0, 1, 3, 4]]}")
0 commit comments