Skip to content

Commit 732272f

Browse files
committed
[GR-76268] Guard thread state map access
1 parent 89010bc commit 732272f

1 file changed

Lines changed: 35 additions & 21 deletions

File tree

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/PythonContext.java

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787
import java.text.MessageFormat;
8888
import java.util.ArrayDeque;
8989
import java.util.ArrayList;
90-
import java.util.Collections;
9190
import java.util.HashMap;
9291
import java.util.HashSet;
9392
import java.util.LinkedList;
@@ -782,7 +781,7 @@ static PythonThreadState getThreadState(Node n) {
782781
@CompilationFinal private TruffleLanguage.Env env;
783782

784783
/* map of Python threads' IDs to the corresponding 'threadStates' */
785-
private final Map<Thread, PythonThreadState> threadStateMapping = Collections.synchronizedMap(new WeakHashMap<>());
784+
private final Map<Thread, PythonThreadState> threadStateMapping = new WeakHashMap<>();
786785
private WeakReference<Thread> mainThread;
787786

788787
/* List of non-Python level threads. Those threads will be joined in finalizeContext. */
@@ -2277,10 +2276,12 @@ public void runShutdownHooks() {
22772276
@TruffleBoundary
22782277
private void disposeThreadStates() {
22792278
Thread currentThread = Thread.currentThread();
2280-
for (Map.Entry<Thread, PythonThreadState> entry : threadStateMapping.entrySet()) {
2281-
entry.getValue().dispose(true, entry.getKey() == currentThread);
2279+
synchronized (this) {
2280+
for (Map.Entry<Thread, PythonThreadState> entry : threadStateMapping.entrySet()) {
2281+
entry.getValue().dispose(true, entry.getKey() == currentThread);
2282+
}
2283+
threadStateMapping.clear();
22822284
}
2283-
threadStateMapping.clear();
22842285
}
22852286

22862287
/**
@@ -2338,7 +2339,10 @@ private void joinPythonThreads() {
23382339
// make a copy of the threads, because the threads will disappear one by one from the
23392340
// threadStateMapping as we're joining them, which gives undefined results for the
23402341
// iterator over keySet
2341-
LinkedList<Thread> threads = new LinkedList<>(threadStateMapping.keySet());
2342+
LinkedList<Thread> threads;
2343+
synchronized (this) {
2344+
threads = new LinkedList<>(threadStateMapping.keySet());
2345+
}
23422346
boolean runViaLauncher = getOption(PythonOptions.RunViaLauncher);
23432347
for (Thread thread : threads) {
23442348
if (thread != Thread.currentThread()) {
@@ -2704,7 +2708,9 @@ public void popCurrentImport() {
27042708

27052709
public Thread[] getThreads() {
27062710
CompilerAsserts.neverPartOfCompilation();
2707-
return threadStateMapping.keySet().toArray(new Thread[0]);
2711+
synchronized (this) {
2712+
return threadStateMapping.keySet().toArray(new Thread[0]);
2713+
}
27082714
}
27092715

27102716
public PythonThreadState getThreadState(PythonLanguage lang) {
@@ -2753,7 +2759,10 @@ public void initializeMultiThreading() {
27532759
public void attachThread(Thread thread, ContextThreadLocal<PythonThreadState> threadState) {
27542760
CompilerAsserts.neverPartOfCompilation();
27552761
PythonThreadState pythonThreadState = threadState.get(thread);
2756-
PythonThreadState previousThreadState = threadStateMapping.put(thread, pythonThreadState);
2762+
PythonThreadState previousThreadState;
2763+
synchronized (this) {
2764+
previousThreadState = threadStateMapping.put(thread, pythonThreadState);
2765+
}
27572766
ReentrantLock initLock = getcApiInitializationLock();
27582767
/*
27592768
* Synchronize with C API initialization so that we do not miss eager initialization of this
@@ -2772,10 +2781,12 @@ public void attachThread(Thread thread, ContextThreadLocal<PythonThreadState> th
27722781
initializeNativeThreadState(pythonThreadState);
27732782
}
27742783
} catch (PException e) {
2775-
if (previousThreadState == null) {
2776-
threadStateMapping.remove(thread);
2777-
} else {
2778-
threadStateMapping.put(thread, previousThreadState);
2784+
synchronized (this) {
2785+
if (previousThreadState == null) {
2786+
threadStateMapping.remove(thread);
2787+
} else {
2788+
threadStateMapping.put(thread, previousThreadState);
2789+
}
27792790
}
27802791
throw e;
27812792
} finally {
@@ -2819,16 +2830,19 @@ public void disposeThread(Thread thread, boolean canRunGuestCode) {
28192830
*/
28202831
public void disposeThread(Thread thread, boolean canRunGuestCode, boolean markShuttingDown) {
28212832
CompilerAsserts.neverPartOfCompilation();
2822-
// check if there is a live sentinel lock
2823-
PythonThreadState ts = threadStateMapping.get(thread);
2824-
if (ts == null) {
2825-
// ts already removed, that is valid during context shutdown for daemon threads
2826-
return;
2827-
}
2828-
if (markShuttingDown) {
2829-
ts.shutdown();
2833+
PythonThreadState ts;
2834+
synchronized (this) {
2835+
// check if there is a live sentinel lock
2836+
ts = threadStateMapping.get(thread);
2837+
if (ts == null) {
2838+
// ts already removed, that is valid during context shutdown for daemon threads
2839+
return;
2840+
}
2841+
if (markShuttingDown) {
2842+
ts.shutdown();
2843+
}
2844+
threadStateMapping.remove(thread);
28302845
}
2831-
threadStateMapping.remove(thread);
28322846
ts.dispose(thread == Thread.currentThread(), markShuttingDown);
28332847
releaseSentinelLock(ts.sentinelLock);
28342848
getSharedMultiprocessingData().removeChildContextThread(PThread.getThreadId(thread));

0 commit comments

Comments
 (0)