 
 class TimestepEmbedder(nn.Module):
     """
-    Embeds scalar timesteps into vector representations.
+    Embeds scalar timesteps into vector representations using a sinusoidal embedding
+    followed by a multilayer perceptron (MLP).
+
+    Parameters
+    ----------
+    hidden_size : int
+        Output dimension of the MLP embedding.
+    frequency_embedding_size : int, optional
+        Size of the input frequency embedding, by default 256.
     """
+
     def __init__(self, hidden_size, frequency_embedding_size=256):
         super().__init__()
         self.mlp = nn.Sequential(
@@ -19,8 +28,9 @@ def __init__(self, hidden_size, frequency_embedding_size=256):
     def timestep_embedding(t, dim, max_period=10000):
         """
         Create sinusoidal timestep embeddings.
+
         :param t: a 1-D Tensor of N indices, one per batch element.
-                          These may be fractional.
+            These may be fractional.
         :param dim: the dimension of the output.
         :param max_period: controls the minimum frequency of the embeddings.
         :return: an (N, D) Tensor of positional embeddings.
@@ -37,15 +47,37 @@ def timestep_embedding(t, dim, max_period=10000):
         return embedding
 
     def forward(self, t):
+        """
+        Forward pass for timestep embedding.
+
+        Parameters
+        ----------
+        t : torch.Tensor
+            1D tensor of scalar timesteps.
+
+        Returns
+        -------
+        torch.Tensor
+            The final embedded representation of shape (N, hidden_size).
+        """
         t = t.view(-1)
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
         t_emb = self.mlp(t_freq)
         return t_emb
 
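The hunk above elides the body of timestep_embedding. For context, a minimal sketch of the standard sinusoidal embedding the docstring describes (the ADM/DiT-style formulation) could look like the following; the standalone function name and exact details are assumptions, not the committed code:

import math
import torch

def sinusoidal_timestep_embedding(t, dim, max_period=10000):
    # Geometric ladder of frequencies from 1 down to 1/max_period.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # zero-pad the last column when dim is odd
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

t = torch.tensor([0.0, 10.0, 250.0])  # one (possibly fractional) timestep per batch element
print(sinusoidal_timestep_embedding(t, 256).shape)  # torch.Size([3, 256])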
 class CategoricalEmbedder(nn.Module):
     """
-    Embeds categorical conditions such as data sources into vector representations.
-    Also handles label dropout for classifier-free guidance.
+    Embeds categorical conditions (e.g., data source labels) into vector representations.
+    Supports label dropout for classifier-free guidance.
+
+    Parameters
+    ----------
+    num_classes : int
+        Number of distinct label categories.
+    hidden_size : int
+        Size of the embedding vectors.
+    dropout_prob : float
+        Probability of label dropout.
     """
     def __init__(self, num_classes, hidden_size, dropout_prob):
         super().__init__()
@@ -57,6 +89,18 @@ def __init__(self, num_classes, hidden_size, dropout_prob):
     def token_drop(self, labels, force_drop_ids=None):
         """
         Drops labels to enable classifier-free guidance.
+
+        Parameters
+        ----------
+        labels : torch.Tensor
+            Tensor of integer labels.
+        force_drop_ids : torch.Tensor or None, optional
+            Boolean mask to force specific labels to be dropped.
+
+        Returns
+        -------
+        torch.Tensor
+            Labels with some entries replaced by a dropout token.
         """
         if force_drop_ids is None:
             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
@@ -65,7 +109,24 @@ def token_drop(self, labels, force_drop_ids=None):
         labels = torch.where(drop_ids, self.num_classes, labels)
         return labels
 
-    def forward(self, labels, train, force_drop_ids=None, t=None):
+    def forward(self, labels, train, force_drop_ids=None):
+        """
+        Forward pass for categorical embedding with optional label dropout.
+
+        Parameters
+        ----------
+        labels : torch.Tensor
+            Tensor of categorical labels.
+        train : bool
+            Whether the model is in training mode.
+        force_drop_ids : torch.Tensor or None, optional
+            Explicit mask for which labels to drop.
+
+        Returns
+        -------
+        torch.Tensor
+            Embedded label representations, with optional noise added during training.
+        """
         labels = labels.long().view(-1)
         use_dropout = self.dropout_prob > 0
         if (train and use_dropout) or (force_drop_ids is not None):
@@ -77,6 +138,20 @@ def forward(self, labels, train, force_drop_ids=None, t=None):
         return embeddings
 
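To make the label-dropout mechanism concrete: a common classifier-free-guidance setup reserves one extra embedding row as a shared "null" label and remaps a random subset of labels to it during training, which is what token_drop does via the num_classes sentinel visible above. A hedged standalone sketch, assuming that table layout:

import torch
import torch.nn as nn

num_classes, hidden_size, dropout_prob = 10, 128, 0.1  # hypothetical sizes

# One extra row at index num_classes acts as the shared null label.
table = nn.Embedding(num_classes + int(dropout_prob > 0), hidden_size)

labels = torch.randint(0, num_classes, (8,))
drop_ids = torch.rand(labels.shape[0]) < dropout_prob
labels = torch.where(drop_ids, torch.full_like(labels, num_classes), labels)
cond_emb = table(labels)  # dropped entries map to the null embedding row

At sampling time, running the model once with the real labels and once with the null label yields the conditional and unconditional predictions that classifier-free guidance interpolates between.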
 class ClusterContinuousEmbedder(nn.Module):
+    """
+    Embeds continuous input features into vector representations using a multilayer perceptron (MLP).
+    Supports optional embedding dropout for classifier-free guidance.
+
+    Parameters
+    ----------
+    input_size : int
+        The size of the input features.
+    hidden_size : int
+        The size of the output embedding vectors.
+    dropout_prob : float
+        Probability of embedding dropout, used for classifier-free guidance.
+    """
     def __init__(self, input_size, hidden_size, dropout_prob):
         super().__init__()
         use_cfg_embedding = dropout_prob > 0
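The diff is truncated inside ClusterContinuousEmbedder.__init__. A minimal sketch of the pattern the new docstring describes, an MLP embedder with a learned null embedding for classifier-free-guidance dropout, might look like this; everything beyond the visible use_cfg_embedding line is an assumption rather than the committed implementation:

import torch
import torch.nn as nn

class ContinuousCFGEmbedderSketch(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        self.dropout_prob = dropout_prob
        if dropout_prob > 0:
            # Learned replacement embedding for dropped conditions.
            self.null_embedding = nn.Embedding(1, hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x, train):
        emb = self.mlp(x)
        if train and self.dropout_prob > 0:
            # Replace a random subset of embeddings with the null embedding.
            drop = torch.rand(x.shape[0], device=x.device) < self.dropout_prob
            null = self.null_embedding.weight[0].expand_as(emb)
            emb = torch.where(drop[:, None], null, emb)
        return emb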