Skip to content

Commit 577e9a7

Browse files
authored
[WebAssembly] WASIP3 Library Call Thread Context Support (#175800)
The [WebAssembly Component Model](https://component-model.bytecodealliance.org/) has added support for [cooperative multithreading](WebAssembly/component-model#557). This has been implemented in the [Wasmtime engine](bytecodealliance/wasmtime#11751) and is part of the wider [WASI preview 3](https://wasi.dev/roadmap#upcoming-wasi-03-releases) project, which is currently tracked [here](https://github.com/orgs/bytecodealliance/projects/16). These changes require updating the way that `__stack_pointer` and `__tls_base` work, but only for a new `wasm32-wasip3` target; other targets will not be touched. Specifically, rather than using a Wasm global to track the stack pointer and TLS base, the new [`context.get/set`](https://github.com/WebAssembly/component-model/blob/main/design/mvp/CanonicalABI.md#-canon-contextget) component model builtin functions will be used (the intention being that runtimes will need to aggressively optimize these calls into single loads/stores). For the justification for this choice, rather than switching out the global at context-switch boundaries, see [this comment](WebAssembly/wasi-libc#691 (comment)) and [this comment](WebAssembly/wasi-libc#691 (comment)). This PR adds support for using library calls instead of globals to hold the stack pointer and TLS base. When used, this thread context ABI emits calls to `__wasm_{get,set}_{stack_pointer,tls_base}` where needed. These functions can then be implemented in `libc`. This is enabled only for the WASIp3 target. A temporary `__wasm_libcall_thread_context__` macro is also defined; it can be removed once `wasi-libc` has fully migrated to the new ABI for the WASIp3 target.
1 parent fc60e08 commit 577e9a7

30 files changed

Lines changed: 472 additions & 84 deletions

clang/lib/Basic/Targets/WebAssembly.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ void WebAssemblyTargetInfo::getTargetDefines(const LangOptions &Opts,
123123
Builder.defineMacro("__wasm_tail_call__");
124124
if (HasWideArithmetic)
125125
Builder.defineMacro("__wasm_wide_arithmetic__");
126+
if (HasLibcallThreadContext)
127+
Builder.defineMacro("__wasm_libcall_thread_context__");
126128
// Note that not all wasm features appear here. For example,
127129
// HasCompactImports
128130

clang/lib/Basic/Targets/WebAssembly.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {
6868
bool HasExtendedConst = false;
6969
bool HasFP16 = false;
7070
bool HasGC = false;
71+
bool HasLibcallThreadContext = false;
7172
bool HasMultiMemory = false;
7273
bool HasMultivalue = false;
7374
bool HasMutableGlobals = false;
@@ -110,6 +111,8 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {
110111
PtrDiffType = SignedLong;
111112
IntPtrType = SignedLong;
112113
}
114+
if (T.getOS() == llvm::Triple::WASIp3)
115+
HasLibcallThreadContext = true;
113116
}
114117

115118
StringRef getABI() const override;

clang/lib/Driver/ToolChains/WebAssembly.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ static bool WantsPthread(const llvm::Triple &Triple, const ArgList &Args) {
8888
return WantsPthread;
8989
}
9090

91+
/// Reports whether \p Triple targets the WASIp3 operating system. The linker
/// job uses the result to decide whether to pass `--libcall-thread-context`.
/// \p Args is not consulted.
static bool WantsLibcallThreadContext(const llvm::Triple &Triple,
                                      const ArgList &Args) {
  const bool TargetsWASIp3 = Triple.getOS() == llvm::Triple::WASIp3;
  return TargetsWASIp3;
}
95+
9196
void wasm::Linker::ConstructJob(Compilation &C, const JobAction &JA,
9297
const InputInfo &Output,
9398
const InputInfoList &Inputs,
@@ -169,6 +174,9 @@ void wasm::Linker::ConstructJob(Compilation &C, const JobAction &JA,
169174

170175
AddLinkerInputs(ToolChain, Inputs, Args, CmdArgs, JA);
171176

177+
if (WantsLibcallThreadContext(ToolChain.getTriple(), Args))
178+
CmdArgs.push_back("--libcall-thread-context");
179+
172180
if (WantsPthread(ToolChain.getTriple(), Args))
173181
CmdArgs.push_back("--shared-memory");
174182

clang/test/Preprocessor/wasm-target-features.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@
217217
// GENERIC-NOT: #define __wasm_simd128__ 1{{$}}
218218
// GENERIC-NOT: #define __wasm_tail_call__ 1{{$}}
219219
// GENERIC-NOT: #define __wasm_wide_arithmetic__ 1{{$}}
220+
// GENERIC-NOT: #define __wasm_libcall_thread_context__ 1{{$}}
220221

221222
// RUN: %clang -E -dM %s -o - 2>&1 \
222223
// RUN: -target wasm32-unknown-unknown -mcpu=bleeding-edge \
@@ -251,3 +252,12 @@
251252
// RUN: | FileCheck %s -check-prefix=BLEEDING-EDGE-NO-SIMD128
252253
//
253254
// BLEEDING-EDGE-NO-SIMD128-NOT: #define __wasm_simd128__ 1{{$}}
255+
256+
// RUN: %clang -E -dM %s -o - 2>&1 \
257+
// RUN: -target wasm32-wasip3 \
258+
// RUN: | FileCheck %s -check-prefix=LIBCALL-THREAD-CONTEXT
259+
// RUN: %clang -E -dM %s -o - 2>&1 \
260+
// RUN: -target wasm64-wasip3 \
261+
// RUN: | FileCheck %s -check-prefix=LIBCALL-THREAD-CONTEXT
262+
263+
// LIBCALL-THREAD-CONTEXT: #define __wasm_libcall_thread_context__ 1{{$}}

lld/test/wasm/stack-pointer-abi.s

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# RUN: llvm-mc -filetype=obj -triple=wasm32-unknown-unknown -o %t.o %s
2+
# RUN: wasm-ld --libcall-thread-context -o %t.libcall.wasm %t.o
3+
# RUN: obj2yaml %t.libcall.wasm | FileCheck %s --check-prefix=LIBCALL
4+
# RUN: wasm-ld -o %t.global.wasm %t.o
5+
# RUN: obj2yaml %t.global.wasm | FileCheck %s --check-prefix=GLOBAL
6+
7+
.globl _start
8+
_start:
9+
.functype _start () -> ()
10+
end_function
11+
12+
# LIBCALL: Name: __init_stack_pointer
13+
# GLOBAL: Name: __stack_pointer
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Test that linking object files with mismatched thread context ABIs fails with an error.
2+
# The presence of an import of __stack_pointer from the env module should be treated
3+
# as an indication that the global thread context ABI is being used.
4+
5+
# RUN: llvm-mc -filetype=obj -triple=wasm32-unknown-unknown -o %t.o %s
6+
# RUN: not wasm-ld --libcall-thread-context %t.o -o %t.wasm 2>&1 | FileCheck %s
7+
8+
# CHECK: object file uses globals for thread context, but --libcall-thread-context was specified
9+
10+
.globl _start
11+
_start:
12+
.functype _start () -> ()
13+
end_function
14+
15+
.globaltype __stack_pointer, i32
16+
17+
.globl use_stack_pointer
18+
use_stack_pointer:
19+
.functype use_stack_pointer () -> ()
20+
global.get __stack_pointer
21+
drop
22+
end_function

lld/test/wasm/tls-libcall.s

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# RUN: llvm-mc -filetype=obj -triple=wasm32-unknown-unknown -o %t.o %s
2+
# RUN: wasm-ld --libcall-thread-context --shared-memory -no-gc-sections -o %t.wasm %t.o
3+
# RUN: obj2yaml %t.wasm | FileCheck %s
4+
# RUN: llvm-objdump -d --no-print-imm-hex --no-show-raw-insn %t.wasm | FileCheck %s --check-prefix=DIS
5+
6+
.globl __wasm_get_tls_base
7+
__wasm_get_tls_base:
8+
.functype __wasm_get_tls_base () -> (i32)
9+
i32.const 0
10+
end_function
11+
12+
.globl _start
13+
_start:
14+
.functype _start () -> (i32)
15+
call __wasm_get_tls_base
16+
i32.const tls1@TLSREL
17+
i32.add
18+
i32.load 0
19+
call __wasm_get_tls_base
20+
i32.const tls2@TLSREL
21+
i32.add
22+
i32.load 0
23+
i32.add
24+
end_function
25+
26+
.section .tdata.tls1,"",@
27+
.globl tls1
28+
tls1:
29+
.int32 1
30+
.size tls1, 4
31+
32+
.section .tdata.tls2,"",@
33+
.globl tls2
34+
tls2:
35+
.int32 2
36+
.size tls2, 4
37+
38+
.section .custom_section.target_features,"",@
39+
.int8 2
40+
.int8 43
41+
.int8 11
42+
.ascii "bulk-memory"
43+
.int8 43
44+
.int8 7
45+
.ascii "atomics"
46+
47+
48+
# CHECK: GlobalNames:
49+
# CHECK-NEXT: - Index: 0
50+
# CHECK-NEXT: Name: __init_stack_pointer
51+
# CHECK-NEXT: - Index: 1
52+
# CHECK-NEXT: Name: __init_tls_base
53+
# CHECK-NEXT: - Index: 2
54+
# CHECK-NEXT: Name: __tls_size
55+
# CHECK-NEXT: - Index: 3
56+
# CHECK-NEXT: Name: __tls_align
57+
58+
# DIS-LABEL: <__wasm_init_memory>:
59+
60+
# DIS-LABEL: <_start>:
61+
# DIS-EMPTY:
62+
# DIS-NEXT: call 4
63+
# DIS-NEXT: i32.const 0
64+
# DIS-NEXT: i32.add
65+
# DIS-NEXT: i32.load 0
66+
# DIS-NEXT: call 4
67+
# DIS-NEXT: i32.const 4
68+
# DIS-NEXT: i32.add
69+
# DIS-NEXT: i32.load 0
70+
# DIS-NEXT: i32.add
71+
# DIS-NEXT: end

lld/wasm/Config.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class Symbol;
3535
class DefinedData;
3636
class GlobalSymbol;
3737
class DefinedFunction;
38+
class UndefinedFunction;
3839
class DefinedGlobal;
3940
class UndefinedGlobal;
4041
class TableSymbol;
@@ -64,6 +65,7 @@ struct Config {
6465
bool growableTable;
6566
bool gcSections;
6667
llvm::StringSet<> keepSections;
68+
bool libcallThreadContext;
6769
std::optional<std::pair<llvm::StringRef, llvm::StringRef>> memoryImport;
6870
std::optional<llvm::StringRef> memoryExport;
6971
bool sharedMemory;
@@ -252,6 +254,14 @@ struct Ctx {
252254
// Used as an address space for function pointers, with each function that
253255
// is used as a function pointer being allocated a slot.
254256
TableSymbol *indirectFunctionTable;
257+
258+
// __wasm_set_tls_base
259+
// Function used to set TLS base in libcall thread context modules.
260+
UndefinedFunction *setTLSBase;
261+
262+
// __wasm_get_tls_base
263+
// Function used to get TLS base in libcall thread context modules.
264+
UndefinedFunction *getTLSBase;
255265
};
256266
WasmSym sym;
257267

lld/wasm/Driver.cpp

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ static void readConfigs(opt::InputArgList &args) {
561561
ctx.arg.soName = args.getLastArgValue(OPT_soname);
562562
ctx.arg.importTable = args.hasArg(OPT_import_table);
563563
ctx.arg.importUndefined = args.hasArg(OPT_import_undefined);
564+
ctx.arg.libcallThreadContext = args.hasArg(OPT_libcall_thread_context);
564565
ctx.arg.ltoo = args::getInteger(args, OPT_lto_O, 2);
565566
if (ctx.arg.ltoo > 3)
566567
error("invalid optimization level for LTO: " + Twine(ctx.arg.ltoo));
@@ -883,6 +884,16 @@ createUndefinedGlobal(StringRef name, llvm::wasm::WasmGlobalType *type) {
883884
return sym;
884885
}
885886

887+
// Adds `name` to the symbol table as an undefined function with the given
// signature and returns the resulting symbol. The symbol is flagged as used in
// a regular object and its name is inserted into the allow-undefined set.
static UndefinedFunction *createUndefinedFunction(StringRef name,
                                                  WasmSignature *signature) {
  auto *added = symtab->addUndefinedFunction(name, std::nullopt, std::nullopt,
                                             WASM_SYMBOL_UNDEFINED, nullptr,
                                             signature, true);
  auto *fn = cast<UndefinedFunction>(added);
  fn->isUsedInRegularObj = true;
  ctx.arg.allowUndefinedSymbols.insert(fn->getName());
  return fn;
}
896+
886897
static InputGlobal *createGlobal(StringRef name, bool isMutable) {
887898
llvm::wasm::WasmGlobal wasmGlobal;
888899
bool is64 = ctx.arg.is64.value_or(false);
@@ -917,17 +928,26 @@ static void createSyntheticSymbols() {
917928
true};
918929
static llvm::wasm::WasmGlobalType mutableGlobalTypeI64 = {WASM_TYPE_I64,
919930
true};
931+
920932
ctx.sym.callCtors = symtab->addSyntheticFunction(
921933
"__wasm_call_ctors", WASM_SYMBOL_VISIBILITY_HIDDEN,
922934
make<SyntheticFunction>(nullSignature, "__wasm_call_ctors"));
923935

924936
bool is64 = ctx.arg.is64.value_or(false);
925937

938+
auto stack_pointer_name =
939+
ctx.arg.libcallThreadContext ? "__init_stack_pointer" : "__stack_pointer";
926940
if (ctx.isPic) {
927-
ctx.sym.stackPointer =
928-
createUndefinedGlobal("__stack_pointer", ctx.arg.is64.value_or(false)
929-
? &mutableGlobalTypeI64
930-
: &mutableGlobalTypeI32);
941+
if (ctx.arg.libcallThreadContext) {
942+
ctx.sym.stackPointer = createUndefinedGlobal(
943+
stack_pointer_name,
944+
ctx.arg.is64.value_or(false) ? &globalTypeI64 : &globalTypeI32);
945+
} else {
946+
ctx.sym.stackPointer = createUndefinedGlobal(stack_pointer_name,
947+
ctx.arg.is64.value_or(false)
948+
? &mutableGlobalTypeI64
949+
: &mutableGlobalTypeI32);
950+
}
931951
// For PIC code, we import two global variables (__memory_base and
932952
// __table_base) from the environment and use these as the offset at
933953
// which to load our static data and function table.
@@ -940,14 +960,18 @@ static void createSyntheticSymbols() {
940960
ctx.sym.tableBase->markLive();
941961
} else {
942962
// For non-PIC code
943-
ctx.sym.stackPointer = createGlobalVariable("__stack_pointer", true);
963+
ctx.sym.stackPointer =
964+
createGlobalVariable(stack_pointer_name, !ctx.arg.libcallThreadContext);
944965
ctx.sym.stackPointer->markLive();
945966
}
946967

947968
if (ctx.arg.sharedMemory) {
948969
// TLS symbols are all hidden/dso-local
970+
auto tls_base_name =
971+
ctx.arg.libcallThreadContext ? "__init_tls_base" : "__tls_base";
949972
ctx.sym.tlsBase =
950-
createGlobalVariable("__tls_base", true, WASM_SYMBOL_VISIBILITY_HIDDEN);
973+
createGlobalVariable(tls_base_name, !ctx.arg.libcallThreadContext,
974+
WASM_SYMBOL_VISIBILITY_HIDDEN);
951975
ctx.sym.tlsSize = createGlobalVariable("__tls_size", false,
952976
WASM_SYMBOL_VISIBILITY_HIDDEN);
953977
ctx.sym.tlsAlign = createGlobalVariable("__tls_align", false,
@@ -956,6 +980,17 @@ static void createSyntheticSymbols() {
956980
"__wasm_init_tls", WASM_SYMBOL_VISIBILITY_HIDDEN,
957981
make<SyntheticFunction>(is64 ? i64ArgSignature : i32ArgSignature,
958982
"__wasm_init_tls"));
983+
if (ctx.arg.libcallThreadContext) {
984+
ctx.sym.tlsBase->markLive();
985+
ctx.sym.tlsSize->markLive();
986+
ctx.sym.tlsAlign->markLive();
987+
static WasmSignature setTLSBaseSignature{{}, {ValType::I32}};
988+
ctx.sym.setTLSBase =
989+
createUndefinedFunction("__wasm_set_tls_base", &setTLSBaseSignature);
990+
static WasmSignature getTLSBaseSignature{{ValType::I32}, {}};
991+
ctx.sym.getTLSBase =
992+
createUndefinedFunction("__wasm_get_tls_base", &getTLSBaseSignature);
993+
}
959994
}
960995
}
961996

lld/wasm/Options.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ def page_size: JJ<"page-size=">,
238238
def initial_memory: JJ<"initial-memory=">,
239239
HelpText<"Initial size of the linear memory">;
240240

241+
def libcall_thread_context: FF<"libcall-thread-context">,
242+
HelpText<"Use library calls for thread context access instead of globals.">;
243+
241244
def max_memory: JJ<"max-memory=">,
242245
HelpText<"Maximum size of the linear memory">;
243246

0 commit comments

Comments
 (0)