Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 51 additions & 53 deletions src/ProvidedTypes.fs
Original file line number Diff line number Diff line change
Expand Up @@ -2874,16 +2874,19 @@ module internal AssemblyReader =

type ILMethodDefs(larr: Lazy<ILMethodDef[]>) =

let mutable lmap = null
let mutable lmap : Dictionary<string, ILMethodDef[]> = null
let syncObj = obj()
let getmap() =
if isNull lmap then
lmap <- Dictionary()
for y in larr.Force() do
let key = y.Name
match lmap.TryGetValue key with
| true, lmpak -> lmap.[key] <- Array.append [| y |] lmpak
| false, _ -> lmap.[key] <- [| y |]
lmap
lock syncObj (fun () ->
if isNull lmap then
let m = Dictionary()
for y in larr.Force() do
let key = y.Name
match m.TryGetValue key with
| true, lmpak -> m.[key] <- Array.append [| y |] lmpak
| false, _ -> m.[key] <- [| y |]
lmap <- m
lmap)

member __.Entries = larr.Force()
member __.FindByName nm =
Expand Down Expand Up @@ -3097,14 +3100,16 @@ module internal AssemblyReader =

and ILTypeDefs(larr: Lazy<(string uoption * string * Lazy<ILTypeDef>)[]>) =

let mutable lmap = null
let mutable lmap : Dictionary<string uoption * string, Lazy<ILTypeDef>> = null
let syncObj = obj()
let getmap() =
if isNull lmap then
lmap <- Dictionary()
for (nsp, nm, ltd) in larr.Force() do
let key = nsp, nm
lmap.[key] <- ltd
lmap
lock syncObj (fun () ->
if isNull lmap then
let m = Dictionary()
for (nsp, nm, ltd) in larr.Force() do
m.[(nsp, nm)] <- ltd
lmap <- m
lmap)

member __.Entries =
[| for (_, _, td) in larr.Force() -> td.Force() |]
Expand Down Expand Up @@ -3142,14 +3147,16 @@ module internal AssemblyReader =
override x.ToString() = "fwd " + x.Name

and ILExportedTypesAndForwarders(larr:Lazy<ILExportedTypeOrForwarder[]>) =
let mutable lmap = null
let mutable lmap : Dictionary<string uoption * string, ILExportedTypeOrForwarder> = null
let syncObj = obj()
let getmap() =
if isNull lmap then
lmap <- Dictionary()
for ltd in larr.Force() do
let key = ltd.Namespace, ltd.Name
lmap.[key] <- ltd
lmap
lock syncObj (fun () ->
if isNull lmap then
let m = Dictionary()
for ltd in larr.Force() do
m.[(ltd.Namespace, ltd.Name)] <- ltd
lmap <- m
lmap)
member __.Entries = larr.Force()
member __.TryFindByName (nsp, nm) = match getmap().TryGetValue ((nsp, nm)) with true, v -> Some v | false, _ -> None

Expand Down Expand Up @@ -4579,34 +4586,29 @@ module internal AssemblyReader =

let mkCacheInt32 lowMem _infile _nm _sz =
if lowMem then (fun f x -> f x) else
let cache = ref null
let cache = Dictionary<int32, _>()
let syncObj = obj()
fun f (idx:int32) ->
let cache =
match !cache with
| null -> cache := new Dictionary<int32, _>(11)
| _ -> ()
!cache
let mutable res = Unchecked.defaultof<_>
let ok = cache.TryGetValue(idx, &res)
if ok then
res
else
let res = f idx
cache.[idx] <- res;
res
lock syncObj (fun () ->
let mutable res = Unchecked.defaultof<_>
if cache.TryGetValue(idx, &res) then res
else
let v = f idx
cache.[idx] <- v
v)

let mkCacheGeneric lowMem _inbase _nm _sz =
if lowMem then (fun f x -> f x) else
let cache = ref null
let cache = Dictionary<'T, _>()
let syncObj = obj()
fun f (idx :'T) ->
let cache =
match !cache with
| null -> cache := new Dictionary<_, _>(11 (* sz:int *) )
| _ -> ()
!cache
match cache.TryGetValue idx with
| true, cached -> cached
| false, _ -> let res = f idx in cache.[idx] <- res; res
lock syncObj (fun () ->
match cache.TryGetValue idx with
| true, cached -> cached
| false, _ ->
let v = f idx
cache.[idx] <- v
v)

let seekFindRow numRows rowChooser =
let mutable i = 1
Expand Down Expand Up @@ -7012,14 +7014,10 @@ namespace ProviderImplementation.ProvidedTypes
// Unique wrapped type definition objects must be translated to unique wrapper objects, based
// on object identity.
type TxTable<'T2>() =
let tab = Dictionary<int, 'T2>()
let tab = System.Collections.Concurrent.ConcurrentDictionary<int, Lazy<'T2>>()
member __.Get inp f =
match tab.TryGetValue inp with
| true, tabVal -> tabVal
| false, _ ->
let res = f()
tab.[inp] <- res
res
let lazyVal = tab.GetOrAdd(inp, fun _ -> lazy (f()))
lazyVal.Value

member __.ContainsKey inp = tab.ContainsKey inp

Expand Down
43 changes: 43 additions & 0 deletions tests/BasicGenerativeProvisionTests.fs
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,46 @@ let ``Generative custom attribute with named property argument encodes and round
Assert.Equal(1, namedArgs.Length)
Assert.Equal("Name", namedArgs.[0].MemberName)
Assert.Equal("MyProp", namedArgs.[0].TypedValue.Value :?> string)

[<Fact>]
let ``TargetTypeDefinition member-wrapper caches are thread-safe under parallel access``() =
// Regression test for https://github.com/fsprojects/FSharp.TypeProviders.SDK/issues/481
// PR #471 introduced lazy caches in TargetTypeDefinition. When multiple threads call
// GetConstructors/GetMethods/etc. concurrently on the same generated type the underlying
// shared caches must not corrupt. Run 8 parallel threads each interrogating every member
// kind on the same TargetTypeDefinition; if any internal collection races, the dictionaries
// will throw InvalidOperationException.
let runtimeAssemblyRefs = Targets.DotNetStandard20FSharpRefs()
let runtimeAssembly = runtimeAssemblyRefs.[0]
let cfg = Testing.MakeSimulatedTypeProviderConfig(__SOURCE_DIRECTORY__, runtimeAssembly, runtimeAssemblyRefs)
let staticArgs = [| box 5; box 6 |]
let tp = GenerativePropertyProviderWithStaticParams cfg :> TypeProviderForNamespaces
let providedNamespace = tp.Namespaces.[0]
let providedTypes = providedNamespace.GetTypes()
let providedType = providedTypes.[0]
let typeName = providedType.Name + (staticArgs |> Seq.map (fun s -> ",\"" + s.ToString() + "\"") |> Seq.reduce (+))
let t = (tp :> ITypeProvider).ApplyStaticArguments(providedType, [| typeName |], staticArgs)
let assemContents = (tp :> ITypeProvider).GetGeneratedAssemblyContents(t.Assembly)
let assem = tp.TargetContext.ReadRelatedAssembly(assemContents)
let typeName2 = providedType.Namespace + "." + typeName
let targetType = assem.GetType(typeName2)
Assert.NotNull(targetType)

let bf = BindingFlags.Public ||| BindingFlags.NonPublic ||| BindingFlags.Instance ||| BindingFlags.Static
let errors = System.Collections.Concurrent.ConcurrentBag<exn>()
let threads =
[| for _ in 1..8 ->
System.Threading.Thread(fun () ->
try
for _ in 1..50 do
targetType.GetConstructors(bf) |> ignore
targetType.GetMethods(bf) |> ignore
targetType.GetFields(bf) |> ignore
targetType.GetProperties(bf) |> ignore
targetType.GetEvents(bf) |> ignore
targetType.GetNestedTypes(bf) |> ignore
with ex ->
errors.Add(ex)) |]
for th in threads do th.Start()
for th in threads do th.Join()
Assert.True(errors.IsEmpty, sprintf "Thread-safety violations: %A" (errors |> Seq.toList))
Loading