diff --git a/src/ProvidedTypes.fs b/src/ProvidedTypes.fs index d755119..087d185 100644 --- a/src/ProvidedTypes.fs +++ b/src/ProvidedTypes.fs @@ -2874,16 +2874,19 @@ module internal AssemblyReader = type ILMethodDefs(larr: Lazy) = - let mutable lmap = null + let mutable lmap : Dictionary = null + let syncObj = obj() let getmap() = - if isNull lmap then - lmap <- Dictionary() - for y in larr.Force() do - let key = y.Name - match lmap.TryGetValue key with - | true, lmpak -> lmap.[key] <- Array.append [| y |] lmpak - | false, _ -> lmap.[key] <- [| y |] - lmap + lock syncObj (fun () -> + if isNull lmap then + let m = Dictionary() + for y in larr.Force() do + let key = y.Name + match m.TryGetValue key with + | true, lmpak -> m.[key] <- Array.append [| y |] lmpak + | false, _ -> m.[key] <- [| y |] + lmap <- m + lmap) member __.Entries = larr.Force() member __.FindByName nm = @@ -3097,14 +3100,16 @@ module internal AssemblyReader = and ILTypeDefs(larr: Lazy<(string uoption * string * Lazy)[]>) = - let mutable lmap = null + let mutable lmap : Dictionary> = null + let syncObj = obj() let getmap() = - if isNull lmap then - lmap <- Dictionary() - for (nsp, nm, ltd) in larr.Force() do - let key = nsp, nm - lmap.[key] <- ltd - lmap + lock syncObj (fun () -> + if isNull lmap then + let m = Dictionary() + for (nsp, nm, ltd) in larr.Force() do + m.[(nsp, nm)] <- ltd + lmap <- m + lmap) member __.Entries = [| for (_, _, td) in larr.Force() -> td.Force() |] @@ -3142,14 +3147,16 @@ module internal AssemblyReader = override x.ToString() = "fwd " + x.Name and ILExportedTypesAndForwarders(larr:Lazy) = - let mutable lmap = null + let mutable lmap : Dictionary = null + let syncObj = obj() let getmap() = - if isNull lmap then - lmap <- Dictionary() - for ltd in larr.Force() do - let key = ltd.Namespace, ltd.Name - lmap.[key] <- ltd - lmap + lock syncObj (fun () -> + if isNull lmap then + let m = Dictionary() + for ltd in larr.Force() do + m.[(ltd.Namespace, ltd.Name)] <- ltd + lmap <- m + lmap) member __.Entries = larr.Force() member __.TryFindByName (nsp, nm) = match getmap().TryGetValue ((nsp, nm)) with true, v -> Some v | false, _ -> None @@ -4579,34 +4586,29 @@ module internal AssemblyReader = let mkCacheInt32 lowMem _infile _nm _sz = if lowMem then (fun f x -> f x) else - let cache = ref null + let cache = Dictionary() + let syncObj = obj() fun f (idx:int32) -> - let cache = - match !cache with - | null -> cache := new Dictionary(11) - | _ -> () - !cache - let mutable res = Unchecked.defaultof<_> - let ok = cache.TryGetValue(idx, &res) - if ok then - res - else - let res = f idx - cache.[idx] <- res; - res + lock syncObj (fun () -> + let mutable res = Unchecked.defaultof<_> + if cache.TryGetValue(idx, &res) then res + else + let v = f idx + cache.[idx] <- v + v) let mkCacheGeneric lowMem _inbase _nm _sz = if lowMem then (fun f x -> f x) else - let cache = ref null + let cache = Dictionary<'T, _>() + let syncObj = obj() fun f (idx :'T) -> - let cache = - match !cache with - | null -> cache := new Dictionary<_, _>(11 (* sz:int *) ) - | _ -> () - !cache - match cache.TryGetValue idx with - | true, cached -> cached - | false, _ -> let res = f idx in cache.[idx] <- res; res + lock syncObj (fun () -> + match cache.TryGetValue idx with + | true, cached -> cached + | false, _ -> + let v = f idx + cache.[idx] <- v + v) let seekFindRow numRows rowChooser = let mutable i = 1 @@ -7012,14 +7014,10 @@ namespace ProviderImplementation.ProvidedTypes // Unique wrapped type definition objects must be translated to unique wrapper objects, based // on object identity. type TxTable<'T2>() = - let tab = Dictionary() + let tab = System.Collections.Concurrent.ConcurrentDictionary>() member __.Get inp f = - match tab.TryGetValue inp with - | true, tabVal -> tabVal - | false, _ -> - let res = f() - tab.[inp] <- res - res + let lazyVal = tab.GetOrAdd(inp, fun _ -> lazy (f())) + lazyVal.Value member __.ContainsKey inp = tab.ContainsKey inp diff --git a/tests/BasicGenerativeProvisionTests.fs b/tests/BasicGenerativeProvisionTests.fs index 36d7ec1..699dc31 100644 --- a/tests/BasicGenerativeProvisionTests.fs +++ b/tests/BasicGenerativeProvisionTests.fs @@ -511,3 +511,46 @@ let ``Generative custom attribute with named property argument encodes and round Assert.Equal(1, namedArgs.Length) Assert.Equal("Name", namedArgs.[0].MemberName) Assert.Equal("MyProp", namedArgs.[0].TypedValue.Value :?> string) + +[] +let ``TargetTypeDefinition member-wrapper caches are thread-safe under parallel access``() = + // Regression test for https://github.com/fsprojects/FSharp.TypeProviders.SDK/issues/481 + // PR #471 introduced lazy caches in TargetTypeDefinition. When multiple threads call + // GetConstructors/GetMethods/etc. concurrently on the same generated type the underlying + // shared caches must not corrupt. Run 8 parallel threads each interrogating every member + // kind on the same TargetTypeDefinition; if any internal collection races, the dictionaries + // will throw InvalidOperationException. + let runtimeAssemblyRefs = Targets.DotNetStandard20FSharpRefs() + let runtimeAssembly = runtimeAssemblyRefs.[0] + let cfg = Testing.MakeSimulatedTypeProviderConfig(__SOURCE_DIRECTORY__, runtimeAssembly, runtimeAssemblyRefs) + let staticArgs = [| box 5; box 6 |] + let tp = GenerativePropertyProviderWithStaticParams cfg :> TypeProviderForNamespaces + let providedNamespace = tp.Namespaces.[0] + let providedTypes = providedNamespace.GetTypes() + let providedType = providedTypes.[0] + let typeName = providedType.Name + (staticArgs |> Seq.map (fun s -> ",\"" + s.ToString() + "\"") |> Seq.reduce (+)) + let t = (tp :> ITypeProvider).ApplyStaticArguments(providedType, [| typeName |], staticArgs) + let assemContents = (tp :> ITypeProvider).GetGeneratedAssemblyContents(t.Assembly) + let assem = tp.TargetContext.ReadRelatedAssembly(assemContents) + let typeName2 = providedType.Namespace + "." + typeName + let targetType = assem.GetType(typeName2) + Assert.NotNull(targetType) + + let bf = BindingFlags.Public ||| BindingFlags.NonPublic ||| BindingFlags.Instance ||| BindingFlags.Static + let errors = System.Collections.Concurrent.ConcurrentBag() + let threads = + [| for _ in 1..8 -> + System.Threading.Thread(fun () -> + try + for _ in 1..50 do + targetType.GetConstructors(bf) |> ignore + targetType.GetMethods(bf) |> ignore + targetType.GetFields(bf) |> ignore + targetType.GetProperties(bf) |> ignore + targetType.GetEvents(bf) |> ignore + targetType.GetNestedTypes(bf) |> ignore + with ex -> + errors.Add(ex)) |] + for th in threads do th.Start() + for th in threads do th.Join() + Assert.True(errors.IsEmpty, sprintf "Thread-safety violations: %A" (errors |> Seq.toList))