66class DANModel :
77 def __init__ (self , F = 257 , num_speakers = 251 ,
88 layer_size = 600 , embedding_size = 40 ,
9- nonlinearity = 'logistic' ,normalize = False ):
9+ nonlinearity = 'logistic' ,normalize = False ,
10+ device = '/cpu:0' ):
1011 """
1112 Initializes a Deep Attractor Network[i]. Default architecture is the
1213 same as for the Lab41 model and the deep clustering model.
@@ -23,6 +24,7 @@ def __init__(self, F=257, num_speakers=251,
2324 nonlinearity: Nonlinearity to use in BLSTM layers (default logistic)
2425 normalize: Do you normalize vectors coming into the final layer?
2526 (default False)
27+ device: Which device to run the model on
2628 """
2729
2830 self .F = F
@@ -34,28 +36,28 @@ def __init__(self, F=257, num_speakers=251,
3436
3537 self .graph = tf .Graph ()
3638 with self .graph .as_default ():
39+ with tf .device (device ):
40+ # Placeholder tensor for the magnitude spectrogram
41+ self .S = tf .placeholder ("float" , [None , None , self .F ])
3742
38- # Placeholder tensor for the magnitude spectrogram
39- self .S = tf .placeholder ("float" , [None , None , self .F ])
43+ # Placeholder tensor for the input data
44+ self .X = tf .placeholder ("float" , [None , None , self .F ])
4045
41- # Placeholder tensor for the input data
42- self .X = tf .placeholder ("float" , [None , None , self .F ])
46+ # Placeholder tensor for the labels/targets
47+ self .y = tf .placeholder ("float" , [None , None , self .F , None ])
4348
44- # Placeholder tensor for the labels/targets
45- self .y = tf .placeholder ("float" , [None , None , self . F , None ])
49+ # Placeholder for the speaker indicies
50+ self .I = tf .placeholder (tf . int32 , [None ,None ])
4651
47- # Placeholder for the speaker indicies
48- self .I = tf .placeholder (tf .int32 , [None ,None ])
49-
50- # Define the speaker vectors to use during training
51- self .speaker_vectors = tf_utils .weight_variable (
52+ # Define the speaker vectors to use during training
53+ self .speaker_vectors = tf_utils .weight_variable (
5254 [self .num_speakers ,self .embedding_size ],
5355 tf .sqrt (2 / self .embedding_size ))
5456
55- # Model methods
56- self .network
57- self .cost
58- self .optimizer
57+ # Model methods
58+ self .network
59+ self .cost
60+ self .optimizer
5961
6062 # Saver
6163 self .saver = tf .train .Saver ()
0 commit comments