11use anyhow:: Result ;
22use notify:: { Config , Event , EventKind , RecommendedWatcher , RecursiveMode , Watcher } ;
33use std:: fs;
4- use std:: fs:: File ;
4+ use std:: fs:: OpenOptions ;
55use std:: io:: Write ;
66use std:: path:: Path ;
77use std:: process;
@@ -173,9 +173,24 @@ pub fn get(kind: LockKind, folder: &str) -> Lock {
173173 let location = lib_dir. join ( kind. file_name ( ) ) ;
174174 let pid = process:: id ( ) ;
175175
176- // When a lockfile already exists we parse its PID: if the process is still alive we refuse to
177- // proceed, otherwise we will overwrite the stale lock with our own PID.
178176 loop {
177+ if let Err ( e) = fs:: create_dir_all ( & lib_dir) {
178+ return Lock :: Error ( Error :: WritingLockfile ( e) ) ;
179+ }
180+
181+ match OpenOptions :: new ( ) . write ( true ) . create_new ( true ) . open ( & location) {
182+ Ok ( mut file) => {
183+ return match file. write ( pid. to_string ( ) . as_bytes ( ) ) {
184+ Ok ( _) => Lock :: Aquired ( pid) ,
185+ Err ( e) => Lock :: Error ( Error :: WritingLockfile ( e) ) ,
186+ } ;
187+ }
188+ Err ( e) if e. kind ( ) == std:: io:: ErrorKind :: AlreadyExists => ( ) ,
189+ Err ( e) => return Lock :: Error ( Error :: WritingLockfile ( e) ) ,
190+ }
191+
192+ // When a lockfile already exists we parse its PID: if the process is still alive we refuse to
193+ // proceed, otherwise we remove the stale lock and try the atomic create again.
179194 match fs:: read_to_string ( & location) {
180195 Ok ( contents) => match contents. parse :: < u32 > ( ) {
181196 Ok ( parsed_pid) if pid_matches_current_process ( parsed_pid) => match kind {
@@ -188,26 +203,17 @@ pub fn get(kind: LockKind, folder: &str) -> Lock {
188203 }
189204 LockKind :: Watch => return Lock :: Error ( Error :: Locked ( parsed_pid) ) ,
190205 } ,
191- Ok ( _) => break ,
206+ Ok ( _) => match fs:: remove_file ( & location) {
207+ Ok ( _) => continue ,
208+ Err ( e) if e. kind ( ) == std:: io:: ErrorKind :: NotFound => continue ,
209+ Err ( e) => return Lock :: Error ( Error :: ReadingLockfile ( kind, e) ) ,
210+ } ,
192211 Err ( e) => return Lock :: Error ( Error :: ParsingLockfile ( e) ) ,
193212 } ,
194- Err ( e) if e. kind ( ) == std:: io:: ErrorKind :: NotFound => break ,
213+ Err ( e) if e. kind ( ) == std:: io:: ErrorKind :: NotFound => continue ,
195214 Err ( e) => return Lock :: Error ( Error :: ReadingLockfile ( kind, e) ) ,
196215 }
197216 }
198-
199- if let Err ( e) = fs:: create_dir_all ( & lib_dir) {
200- return Lock :: Error ( Error :: WritingLockfile ( e) ) ;
201- }
202-
203- // Rewrite the lockfile with our own PID.
204- match File :: create ( & location) {
205- Ok ( mut file) => match file. write ( pid. to_string ( ) . as_bytes ( ) ) {
206- Ok ( _) => Lock :: Aquired ( pid) ,
207- Err ( e) => Lock :: Error ( Error :: WritingLockfile ( e) ) ,
208- } ,
209- Err ( e) => Lock :: Error ( Error :: WritingLockfile ( e) ) ,
210- }
211217}
212218
213219pub fn get_lock_or_exit ( kind : LockKind , folder : & str ) -> Lock {
@@ -239,6 +245,7 @@ pub fn drop_lock(kind: LockKind, folder: &str) -> Result<()> {
239245mod tests {
240246 use super :: * ;
241247 use std:: fs;
248+ use std:: sync:: { Arc , Barrier } ;
242249 use std:: thread;
243250 use std:: time:: Duration ;
244251 use tempfile:: TempDir ;
@@ -291,6 +298,37 @@ mod tests {
291298 ) ;
292299 }
293300
301+ #[ test]
302+ fn only_one_concurrent_caller_acquires_lock ( ) {
303+ let temp_dir = TempDir :: new ( ) . expect ( "temp dir should be created" ) ;
304+ let project_folder = temp_dir. path ( ) . join ( "project" ) ;
305+ fs:: create_dir ( & project_folder) . expect ( "project folder should be created" ) ;
306+
307+ let caller_count = 12 ;
308+ let start = Arc :: new ( Barrier :: new ( caller_count) ) ;
309+ let handles = ( 0 ..caller_count)
310+ . map ( |_| {
311+ let start = Arc :: clone ( & start) ;
312+ let project_folder = project_folder. clone ( ) ;
313+ thread:: spawn ( move || {
314+ start. wait ( ) ;
315+ get (
316+ LockKind :: Watch ,
317+ project_folder. to_str ( ) . expect ( "path should be valid" ) ,
318+ )
319+ } )
320+ } )
321+ . collect :: < Vec < _ > > ( ) ;
322+
323+ let acquired_count = handles
324+ . into_iter ( )
325+ . map ( |handle| handle. join ( ) . expect ( "lock thread should complete" ) )
326+ . filter ( |lock| matches ! ( lock, Lock :: Aquired ( _) ) )
327+ . count ( ) ;
328+
329+ assert_eq ! ( acquired_count, 1 ) ;
330+ }
331+
294332 #[ test]
295333 fn ignores_stale_lock_for_unrelated_process_name ( ) {
296334 let temp_dir = TempDir :: new ( ) . expect ( "temp dir should be created" ) ;
0 commit comments