Skip to content

Commit 36ee769

Browse files
author
Kevin Chang
committed
with bindsnet
1 parent 08e02c9 commit 36ee769

3 files changed

Lines changed: 57 additions & 553 deletions

File tree

bindsnet/datasets/contrastive_transforms.py

Lines changed: 48 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -2,170 +2,96 @@
22
import random
33
from typing import Optional
44

5-
def generate_positive_sample(
5+
def prepend_label_to_image(
66
x_input: torch.Tensor,
7-
true_label: int,
7+
label: int,
88
num_classes: int,
99
) -> torch.Tensor:
1010
"""
11-
Generates a positive sample x_pos by embedding the true label into x_input.
11+
Generates a sample by embedding a label into x_input.
1212
1313
The first `num_classes` elements of the output vector are set to 0,
14-
except at the index corresponding to the `true_label`, where it's
14+
except at the index corresponding to the `label`, where it's
1515
set to the maximum value of `x_input`. The remaining elements are copied
16-
from `x_input`.
16+
from `x_input` starting after the label section.
1717
1818
Args:
1919
x_input: The original flattened input vector (1D Tensor).
20-
true_label: The 0-indexed true class label of x_input.
20+
label: The 0-indexed class label to embed.
2121
num_classes: The total number of classes (c).
2222
2323
Returns:
24-
A new tensor x_pos with the true label embedded.
24+
A new tensor with the label embedded (SAME SIZE as input).
2525
"""
2626
if not isinstance(x_input, torch.Tensor) or x_input.ndim != 1:
2727
raise ValueError("x_input must be a 1D PyTorch Tensor.")
28-
if not (0 <= true_label < num_classes):
28+
if not (0 <= label < num_classes):
2929
raise ValueError(
30-
f"True label {true_label} is out of bounds for {num_classes} classes."
30+
f"Label {label} is out of bounds for {num_classes} classes."
3131
)
3232
if num_classes <= 0:
3333
raise ValueError("num_classes must be positive.")
34-
35-
34+
3635
d = x_input.shape[0]
37-
m = torch.max(x_input) if d > 0 else torch.tensor(0.0, dtype=x_input.dtype) # Handle empty x_input
38-
39-
# Initialize x_pos with zeros, matching dtype and device of x_input
40-
x_pos = torch.zeros_like(x_input)
41-
42-
# Part 1: Embed the true label in the first `num_classes` elements.
43-
# All these elements are 0, except at the index `true_label`.
44-
if true_label < min(num_classes, d): # Ensure true_label is within bounds of the modifiable part
45-
x_pos[true_label] = m
46-
47-
# Part 2: Copy the rest of the original input vector.
48-
# These are elements from index `num_classes` to `d-1`.
49-
if d > num_classes:
50-
x_pos[num_classes:] = x_input[num_classes:]
51-
# If num_classes >= d, only the first d elements are modified, and the above copy is skipped.
52-
53-
return x_pos
54-
55-
56-
def generate_negative_sample(
57-
x_input: torch.Tensor,
58-
true_label: int,
59-
num_classes: int,
60-
false_label_override: Optional[int] = None,
61-
) -> torch.Tensor:
62-
"""
63-
Generates a negative sample x_neg by embedding a false label into x_input.
64-
65-
The first `num_classes` elements of the output vector are set to 0,
66-
except at the index corresponding to the `chosen_false_label`, where it's
67-
set to the maximum value of `x_input`. The remaining elements are copied
68-
from `x_input`.
69-
70-
Args:
71-
x_input: The original flattened input vector (1D Tensor).
72-
true_label: The 0-indexed true class label of x_input.
73-
num_classes: The total number of classes (c).
74-
false_label_override: Optional. A specific 0-indexed false class label to embed.
75-
If None, a false label will be chosen randomly, ensuring
76-
it's different from `true_label`. This parameter can be
77-
used if implementing a "hard labeling" strategy externally.
78-
Returns:
79-
A new tensor x_neg with the false label embedded.
80-
"""
81-
if not isinstance(x_input, torch.Tensor) or x_input.ndim != 1:
82-
raise ValueError("x_input must be a 1D PyTorch Tensor.")
83-
if not (0 <= true_label < num_classes):
36+
if num_classes > d:
8437
raise ValueError(
85-
f"True label {true_label} is out of bounds for {num_classes} classes."
38+
f"num_classes ({num_classes}) cannot be larger than input size ({d})"
8639
)
87-
if num_classes <= 0:
88-
raise ValueError("num_classes must be positive.")
8940

90-
chosen_false_label: int
91-
if false_label_override is not None:
92-
chosen_false_label = false_label_override
93-
if not (0 <= chosen_false_label < num_classes):
94-
raise ValueError(
95-
f"Provided false_label_override {chosen_false_label} is out of bounds for {num_classes} classes."
96-
)
97-
if chosen_false_label == true_label:
98-
raise ValueError(
99-
f"Provided false_label_override {chosen_false_label} cannot be the same as true_label {true_label}."
100-
)
101-
else:
102-
if num_classes <= 1:
103-
raise ValueError(
104-
"Cannot randomly choose a distinct false label with less than 2 classes."
105-
)
106-
possible_false_labels = [i for i in range(num_classes) if i != true_label]
107-
if not possible_false_labels: # Should be caught by num_classes <= 1
108-
raise ValueError(f"No available false labels to choose from for true_label {true_label} with {num_classes} classes.")
109-
chosen_false_label = random.choice(possible_false_labels)
41+
m = torch.max(x_input) if d > 0 else torch.tensor(0.0, dtype=x_input.dtype)
11042

111-
d = x_input.shape[0]
112-
m = torch.max(x_input) if d > 0 else torch.tensor(0.0, dtype=x_input.dtype) # Handle empty x_input
43+
# FIX: Create output tensor with SAME SIZE as input (not larger)
44+
x_output = torch.zeros_like(x_input) # Same size as x_input
11345

114-
# Initialize x_neg with zeros, matching dtype and device of x_input
115-
x_neg = torch.zeros_like(x_input)
46+
# Part 1: Embed the label in the first `num_classes` elements
47+
x_output[label] = m
11648

117-
# Part 1: Embed the false label in the first `num_classes` elements.
118-
if chosen_false_label < min(num_classes, d): # Ensure chosen_false_label is within bounds
119-
x_neg[chosen_false_label] = m
120-
121-
# Part 2: Copy the rest of the original input vector.
49+
# Part 2: Copy the remaining original input elements (skip first num_classes)
12250
if d > num_classes:
123-
x_neg[num_classes:] = x_input[num_classes:]
124-
125-
return x_neg
51+
x_output[num_classes:] = x_input[num_classes:]
12652

53+
return x_output
12754

12855

12956
# --- Example Usage (for demonstration if you run this file directly) ---
13057
if __name__ == "__main__":
13158
# Example: 10 features in original input, 4 classes
132-
original_x = torch.rand(10) # Random data
59+
original_x = torch.rand(10) # Random data
13360
true_class_label = 1
13461
total_classes = 4
13562

13663
print(f"Original x_input: {original_x}")
64+
print(f"Original shape: {original_x.shape}")
13765
print(f"True label: {true_class_label}")
13866
print(f"Num classes: {total_classes}")
139-
print("-" * 30)
67+
print("-" * 50)
14068

141-
x_positive = generate_positive_sample(
69+
# Positive sample (true label)
70+
x_positive = prepend_label_to_image(
14271
original_x, true_class_label, total_classes
14372
)
14473
print(f"x_pos (true label {true_class_label} embedded): {x_positive}")
145-
print("-" * 30)
146-
147-
x_negative_random = generate_negative_sample(
148-
original_x, true_class_label, total_classes
74+
print(f"x_pos shape: {x_positive.shape}") # Should be [10] - same as input
75+
print("-" * 50)
76+
77+
# Negative sample (random false label)
78+
available_labels = [i for i in range(total_classes) if i != true_class_label]
79+
rand_negative_label = random.choice(available_labels)
80+
81+
x_negative_random = prepend_label_to_image(
82+
original_x, rand_negative_label, total_classes
14983
)
150-
print(f"x_neg (random false label embedded): {x_negative_random}")
151-
print("-" * 30)
152-
153-
specific_false = 3
154-
if specific_false == true_class_label: # Ensure it's actually false for the example
155-
specific_false = 0 if true_class_label !=0 else 2
156-
157-
x_negative_specific = generate_negative_sample(
158-
original_x, true_class_label, total_classes, false_label_override=specific_false
159-
)
160-
print(f"x_neg (specific false label {specific_false} embedded): {x_negative_specific}")
161-
print("-" * 30)
162-
163-
# Edge case: num_classes > len(x_input)
164-
short_x = torch.tensor([0.1, 0.9])
165-
true_short_label = 0
166-
classes_short = 3
167-
print(f"Short Original x_input: {short_x}")
168-
x_pos_short = generate_positive_sample(short_x, true_short_label, classes_short)
169-
print(f"x_pos_short (true label {true_short_label}, num_classes {classes_short}): {x_pos_short}")
170-
x_neg_short = generate_negative_sample(short_x, true_short_label, classes_short, false_label_override=1)
171-
print(f"x_neg_short (false label 1, num_classes {classes_short}): {x_neg_short}")
84+
print(f"x_neg (random false label {rand_negative_label} embedded): {x_negative_random}")
85+
print(f"x_neg shape: {x_negative_random.shape}") # Should be [10] - same as input
86+
print("-" * 50)
87+
88+
# Test with MNIST-like data
89+
mnist_like = torch.rand(784) # Like MNIST flattened
90+
num_classes_mnist = 5
91+
92+
print(f"MNIST-like original: shape {mnist_like.shape}")
93+
x_mnist_pos = prepend_label_to_image(mnist_like, 2, num_classes_mnist)
94+
print(f"MNIST-like after embedding label 2: shape {x_mnist_pos.shape}") # Should be [784]
95+
print(f"First 10 elements: {x_mnist_pos[:10]}")
96+
print(f"Label embedded at position 2: {x_mnist_pos[2]}") # Should be max value
97+
print(f"Positions 0,1,3,4 should be 0: {x_mnist_pos[[0,1,3,4]]}")

0 commit comments

Comments
 (0)