Skip to content

Commit a04d026

Browse files
committed
Codegen overloaded LLVM intrinsics using their name
1 parent 6368fd5 commit a04d026

5 files changed

Lines changed: 147 additions & 5 deletions

File tree

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 114 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1237,12 +1237,123 @@ fn autocast<'ll>(
12371237
}
12381238
}
12391239

1240+
/// Parses a canonical unsigned decimal integer from the front of `string`.
///
/// On success the digits are removed from `string` and the parsed value is returned.
///
/// Returns `None`, leaving `string` untouched, when:
/// - `string` does not start with an ASCII digit,
/// - the number carries a redundant leading zero (e.g. `01`), or
/// - the value does not fit in a `u64`.
fn parse_integer(string: &mut &[u8]) -> Option<u64> {
    // Copy the inner reference so the sub-slices below borrow from the underlying bytes
    // rather than from `string` itself; this lets us assign back to `*string` at the end.
    let bytes = *string;

    let digit_count = bytes.iter().take_while(|byte| byte.is_ascii_digit()).count();
    let (digits, remainder) = bytes.split_at(digit_count);

    match digits {
        // No digits at all.
        [] => return None,
        // A leading zero is only allowed for the number `0` itself.
        [b'0', _, ..] => return None,
        _ => {}
    }

    // Checked arithmetic: an over-long digit string must be rejected instead of overflowing
    // (which would panic with overflow checks enabled and silently wrap otherwise).
    let number = digits.iter().try_fold(0u64, |number, &digit| {
        number.checked_mul(10)?.checked_add(u64::from(digit - b'0'))
    })?;

    *string = remainder;
    Some(number)
}
}
1255+
1256+
/// Removes `prefix` from the front of `slice` if it is present.
///
/// Returns `true` and advances `slice` past the prefix on a match; returns `false` and
/// leaves `slice` unchanged otherwise.
fn strip_off_prefix(slice: &mut &[u8], prefix: &[u8]) -> bool {
    let Some(rest) = slice.strip_prefix(prefix) else {
        return false;
    };
    *slice = rest;
    true
}
}
1259+
1260+
fn demangle_type_str<'ll>(cx: &CodegenCx<'ll, '_>, slice: &mut &[u8]) -> Option<&'ll Type> {
1261+
Some(if strip_off_prefix(slice, b"isVoid") {
1262+
cx.type_void()
1263+
} else if strip_off_prefix(slice, b"f16") {
1264+
cx.type_f16()
1265+
} else if strip_off_prefix(slice, b"bf16") {
1266+
cx.type_bf16()
1267+
} else if strip_off_prefix(slice, b"f32") {
1268+
cx.type_f32()
1269+
} else if strip_off_prefix(slice, b"f64") {
1270+
cx.type_f64()
1271+
} else if strip_off_prefix(slice, b"f128") {
1272+
cx.type_f128()
1273+
} else if strip_off_prefix(slice, b"i") {
1274+
let width = parse_integer(slice)?;
1275+
cx.type_ix(width)
1276+
} else if strip_off_prefix(slice, b"p") {
1277+
let address_space = parse_integer(slice)?;
1278+
cx.type_ptr_ext(AddressSpace(address_space as u32))
1279+
} else if strip_off_prefix(slice, b"v") {
1280+
let length = parse_integer(slice)?;
1281+
let element_type = demangle_type_str(cx, slice)?;
1282+
cx.type_vector(element_type, length)
1283+
} else if strip_off_prefix(slice, b"nvx") {
1284+
let length = parse_integer(slice)?;
1285+
let element_type = demangle_type_str(cx, slice)?;
1286+
cx.type_scalable_vector(element_type, length)
1287+
} else if strip_off_prefix(slice, b"a") {
1288+
let length = parse_integer(slice)?;
1289+
let element_type = demangle_type_str(cx, slice)?;
1290+
cx.type_array(element_type, length)
1291+
} else if strip_off_prefix(slice, b"sl_") {
1292+
let mut elements = Vec::new();
1293+
1294+
loop {
1295+
if let Some(remainder) = slice.strip_prefix(b"s")
1296+
&& !remainder.starts_with(b"_")
1297+
&& !remainder.starts_with(b"l_")
1298+
{
1299+
*slice = remainder;
1300+
break cx.type_struct(&elements, true);
1301+
}
1302+
elements.push(demangle_type_str(cx, slice)?);
1303+
}
1304+
} else if strip_off_prefix(slice, b"f_") {
1305+
let return_type = demangle_type_str(cx, slice)?;
1306+
let mut arguments = Vec::new();
1307+
1308+
loop {
1309+
if let Some(remainder) = slice.strip_prefix(b"f")
1310+
&& !remainder.starts_with(b"_")
1311+
{
1312+
*slice = remainder;
1313+
break cx.type_func(&arguments, return_type);
1314+
}
1315+
if strip_off_prefix(slice, b"varargf") {
1316+
break cx.type_variadic_func(&arguments, return_type);
1317+
}
1318+
arguments.push(demangle_type_str(cx, slice)?);
1319+
}
1320+
} else {
1321+
return None;
1322+
})
1323+
}
1324+
1325+
fn parse_type_parameters<'ll, 'tcx>(
1326+
cx: &CodegenCx<'ll, 'tcx>,
1327+
intrinsic: llvm::Intrinsic,
1328+
name: &str,
1329+
) -> Option<Vec<&'ll Type>> {
1330+
let base_name: &'ll [u8] = intrinsic.base_name();
1331+
1332+
let slice = &mut name.as_bytes().strip_prefix(base_name).unwrap();
1333+
1334+
if !intrinsic.is_overloaded() {
1335+
return slice.is_empty().then(|| Vec::new());
1336+
}
1337+
1338+
let mut type_params = Vec::new();
1339+
1340+
while !slice.is_empty() {
1341+
if !strip_off_prefix(slice, b".") {
1342+
return None;
1343+
}
1344+
1345+
type_params.push(demangle_type_str(cx, slice)?);
1346+
}
1347+
1348+
Some(type_params)
1349+
}
1350+
12401351
fn intrinsic_fn<'ll, 'tcx>(
12411352
bx: &Builder<'_, 'll, 'tcx>,
12421353
name: &str,
12431354
rust_return_ty: &'ll Type,
12441355
rust_argument_tys: Vec<&'ll Type>,
1245-
instance: ty::Instance<'tcx>,
1356+
instance: Instance<'tcx>,
12461357
) -> &'ll Value {
12471358
let tcx = bx.tcx;
12481359

@@ -1268,10 +1379,9 @@ fn intrinsic_fn<'ll, 'tcx>(
12681379
}
12691380

12701381
if let Some(intrinsic) = intrinsic
1271-
&& !intrinsic.is_overloaded()
1382+
&& let Some(type_params) = parse_type_parameters(bx.cx, intrinsic, name)
12721383
{
1273-
// FIXME: also do this for overloaded intrinsics
1274-
let llfn = intrinsic.get_declaration(bx.llmod, &[]);
1384+
let llfn = intrinsic.get_declaration(bx.llmod, &type_params);
12751385
let llvm_fn_ty = bx.get_type_of_global(llfn);
12761386

12771387
let llvm_return_ty = bx.get_return_type(llvm_fn_ty);

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1124,6 +1124,10 @@ unsafe extern "C" {
11241124
NewFn: &mut Option<&'a Value>,
11251125
) -> bool;
11261126
pub(crate) fn LLVMRustIsTargetIntrinsic(ID: NonZero<c_uint>) -> bool;
1127+
/// Returns a pointer to the base name of the intrinsic with the given `ID` and writes the
/// name's length in bytes to `NameLength` (the C++ wrapper forwards the `data()`/`size()`
/// of `Intrinsic::getBaseName(ID)`).
///
/// NOTE(review): read the name through the returned length; presumably it is not
/// guaranteed to be NUL-terminated at that length — confirm.
pub(crate) fn LLVMRustIntrinsicGetBaseName(
1128+
ID: NonZero<c_uint>,
1129+
NameLength: &mut size_t,
1130+
) -> *const c_char;
11271131

11281132
// Operations on parameters
11291133
pub(crate) fn LLVMIsAArgument(Val: &Value) -> Option<&Value>;

compiler/rustc_codegen_llvm/src/llvm/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22

33
use std::ffi::{CStr, CString};
44
use std::num::NonZero;
5-
use std::ptr;
65
use std::string::FromUtf8Error;
6+
use std::{ptr, slice};
77

88
use libc::c_uint;
99
use rustc_abi::{AddressSpace, Align, Size, WrappingRange};
@@ -340,6 +340,12 @@ impl Intrinsic {
340340
LLVMGetIntrinsicDeclaration(llmod, self.id, type_params.as_ptr(), type_params.len())
341341
}
342342
}
343+
344+
pub(crate) fn base_name<'ll>(self) -> &'ll [u8] {
345+
let mut length = 0;
346+
let ptr = unsafe { LLVMRustIntrinsicGetBaseName(self.id, &mut length) };
347+
unsafe { slice::from_raw_parts(ptr.cast(), length) }
348+
}
343349
}
344350

345351
/// Safe wrapper for `LLVMSetValueName2` from a byte slice

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1841,6 +1841,12 @@ extern "C" bool LLVMRustIsTargetIntrinsic(unsigned ID) {
18411841
return Intrinsic::isTargetIntrinsic(ID);
18421842
}
18431843

1844+
extern "C" const char* LLVMRustIntrinsicGetBaseName(unsigned ID, size_t *NameLength) {
1845+
auto baseName = Intrinsic::getBaseName(ID);
1846+
*NameLength = baseName.size();
1847+
return baseName.data();
1848+
}
1849+
18441850
// Statically assert that the fixed metadata kind IDs declared in
18451851
// `metadata_kind.rs` match the ones actually used by LLVM.
18461852
#define FIXED_MD_KIND(VARIANT, VALUE) \

tests/codegen-llvm/inject-autocast.rs

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,20 @@ pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) ->
103103
foo(m, n, k, a, b, c)
104104
}
105105

106+
// Exercises an overloaded intrinsic referenced by its mangled name: `llvm.sqrt.v8bf16` is
// declared with `i16x8` on the Rust side, and the directives below expect the argument to be
// bitcast to `<8 x bfloat>` before the call and the result to be bitcast back to `<8 x i16>`.
// CHECK-LABEL: @overloaded_bf16_autocast
107+
#[no_mangle]
108+
pub unsafe fn overloaded_bf16_autocast(a: i16x8) -> i16x8 {
109+
extern "unadjusted" {
110+
#[link_name = "llvm.sqrt.v8bf16"]
111+
fn foo(a: i16x8) -> i16x8;
112+
}
113+
114+
// CHECK: [[A:%[0-9]+]] = bitcast <8 x i16> {{.*}} to <8 x bfloat>
115+
// CHECK: [[B:%[0-9]+]] = call <8 x bfloat> @llvm.sqrt.v8bf16(<8 x bfloat> [[A]])
116+
// CHECK: bitcast <8 x bfloat> [[B]] to <8 x i16>
117+
foo(a)
118+
}
}
119+
106120
// CHECK: declare { i32, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64>, <2 x i64> } @llvm.x86.encodekey128(i32, <2 x i64>)
107121

108122
// CHECK: declare { <2 x i1>, <2 x i1> } @llvm.x86.avx512.vp2intersect.q.128(<2 x i64>, <2 x i64>)
@@ -116,3 +130,5 @@ pub unsafe fn amx_autocast(m: u16, n: u16, k: u16, a: Tile, b: Tile, c: Tile) ->
116130
// CHECK: declare x86_amx @llvm.x86.cast.vector.to.tile.v1024i8(<1024 x i8>)
117131

118132
// CHECK: declare <1024 x i8> @llvm.x86.cast.tile.to.vector.v1024i8(x86_amx)
133+
134+
// CHECK: declare <8 x bfloat> @llvm.sqrt.v8bf16(<8 x bfloat>)

0 commit comments

Comments
 (0)