Skip to content

Commit 7976aab

Browse files
committed
Add LLVM JIT support for reference parameters
1 parent e11a7db commit 7976aab

2 files changed

Lines changed: 58 additions & 21 deletions

File tree

src/llvm_jit.rs

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ impl crate::Type {
378378
crate::Type::Name(_, _) => ptr_ty.into(),
379379
crate::Type::Array(_, _) => ptr_ty.into(),
380380
crate::Type::Slice(_) => ptr_ty.into(),
381+
crate::Type::Reference(_) => ptr_ty.into(),
381382
crate::Type::Tuple(_) => ptr_ty.into(),
382383
crate::Type::Anon(_) => panic!("anonymous type in codegen"),
383384
crate::Type::Var(_) => panic!("type var in codegen"),
@@ -860,10 +861,15 @@ impl<'ctx> LLVMJITState<'ctx> {
860861
for (i, param) in decl.params.iter().enumerate() {
861862
let param_val = params[param_idx + i];
862863
let ty = param.ty.expect("param ty");
864+
let storage_ty = if let crate::Type::Reference(inner) = &*ty {
865+
*inner
866+
} else {
867+
ty
868+
};
863869

864870
// Slice parameters get noalias: the type checker enforces that no two
865871
// slice arguments in a call can alias (Fortran-style restrict).
866-
if matches!(&*ty, crate::Type::Slice(_)) {
872+
if matches!(&*ty, crate::Type::Slice(_) | crate::Type::Reference(_)) {
867873
let noalias = self.context.create_enum_attribute(
868874
inkwell::attributes::Attribute::get_named_enum_kind_id("noalias"),
869875
0,
@@ -881,8 +887,10 @@ impl<'ctx> LLVMJITState<'ctx> {
881887
.unwrap();
882888
self.builder.build_store(alloca, param_val).unwrap();
883889
variables.insert(param.name.to_string(), alloca);
884-
variable_types.insert(param.name.to_string(), ty);
885-
let_bindings.insert(param.name.to_string());
890+
variable_types.insert(param.name.to_string(), storage_ty);
891+
if !matches!(&*ty, crate::Type::Reference(_)) {
892+
let_bindings.insert(param.name.to_string());
893+
}
886894
}
887895

888896
let mut trans = FunctionTranslator {
@@ -1336,7 +1344,7 @@ impl<'a, 'ctx> FunctionTranslator<'a, 'ctx> {
13361344
})
13371345
.unwrap_or(decl.types[expr]),
13381346
Expr::ArrayIndex(arr_id, _) => match &*self.representation_type(*arr_id, decl) {
1339-
crate::Type::Array(elem, _) | crate::Type::Slice(elem) => *elem,
1347+
crate::Type::Array(elem, _) | crate::Type::Slice(elem) | crate::Type::Reference(elem) => *elem,
13401348
_ => decl.types[expr],
13411349
},
13421350
_ => decl.types[expr],
@@ -1510,10 +1518,24 @@ impl<'a, 'ctx> FunctionTranslator<'a, 'ctx> {
15101518
alloca
15111519
}
15121520
} else {
1513-
self.builder()
1514-
.build_load(self.ptr_ty(), alloca, "var_addr")
1515-
.unwrap()
1516-
.into_pointer_value()
1521+
if let Some(ty) = self.variable_types.get(name.as_str()).copied() {
1522+
if ty.is_ptr() {
1523+
self.builder()
1524+
.build_load(ty.llvm_basic_type(self.ctx()), alloca, "var_addr")
1525+
.unwrap()
1526+
.into_pointer_value()
1527+
} else {
1528+
self.builder()
1529+
.build_load(self.ptr_ty(), alloca, "var_addr")
1530+
.unwrap()
1531+
.into_pointer_value()
1532+
}
1533+
} else {
1534+
self.builder()
1535+
.build_load(self.ptr_ty(), alloca, "var_addr")
1536+
.unwrap()
1537+
.into_pointer_value()
1538+
}
15171539
}
15181540
} else if let Some(&offset) = self.state.globals.get(name) {
15191541
self.ptr_at_offset(self.globals_base, offset as u64)
@@ -1626,15 +1648,20 @@ impl<'a, 'ctx> FunctionTranslator<'a, 'ctx> {
16261648
.build_load(ty.llvm_basic_type(self.ctx()), alloca, &**name)
16271649
.unwrap()
16281650
} else {
1629-
// var binding: alloca holds a pointer to the slot; load ptr, then load value.
1630-
let slot_ptr = self
1631-
.builder()
1632-
.build_load(self.ptr_ty(), alloca, "var_ptr")
1633-
.unwrap()
1634-
.into_pointer_value();
1635-
self.builder()
1636-
.build_load(ty.llvm_basic_type(self.ctx()), slot_ptr, &**name)
1637-
.unwrap()
1651+
let stored = self.builder().build_load(self.ptr_ty(), alloca, "var_ptr").unwrap();
1652+
if let Some(var_ty) = self.variable_types.get(name.as_str()).copied() {
1653+
if var_ty.is_ptr() {
1654+
stored
1655+
} else {
1656+
self.builder()
1657+
.build_load(ty.llvm_basic_type(self.ctx()), stored.into_pointer_value(), &**name)
1658+
.unwrap()
1659+
}
1660+
} else {
1661+
self.builder()
1662+
.build_load(ty.llvm_basic_type(self.ctx()), stored.into_pointer_value(), &**name)
1663+
.unwrap()
1664+
}
16381665
}
16391666
} else if let Some(&offset) = self.state.globals.get(name) {
16401667
let addr = self.ptr_at_offset(self.globals_base, offset as u64);
@@ -2951,8 +2978,12 @@ impl<'a, 'ctx> FunctionTranslator<'a, 'ctx> {
29512978

29522979
let mut args: Vec<BasicMetadataValueEnum<'ctx>> = vec![context.into()];
29532980
for (i, arg_id) in arg_ids.iter().enumerate() {
2954-
let arg_val = self.translate_expr(*arg_id, decl);
29552981
let param_ty = callee_decl.params[i].ty.unwrap();
2982+
let arg_val = if matches!(&*param_ty, crate::Type::Reference(_)) {
2983+
self.translate_lvalue(*arg_id, decl).into()
2984+
} else {
2985+
self.translate_expr(*arg_id, decl)
2986+
};
29562987
if matches!(&*param_ty, crate::Type::Slice(_)) {
29572988
let actual_ty = self.representation_type(*arg_id, decl);
29582989
match &*actual_ty {
@@ -3081,8 +3112,15 @@ impl<'a, 'ctx> FunctionTranslator<'a, 'ctx> {
30813112
let full_param_types = full_fn_ty.get_param_types();
30823113
let arg_start = args.len(); // offset for globals+closure+output_ptr
30833114
for (i, arg_id) in arg_ids.iter().enumerate() {
3084-
let arg_val = self.translate_expr(*arg_id, decl);
3085-
if i < param_types.len() && is_slice(param_types[i]) {
3115+
let expected_ty = param_types.get(i).copied();
3116+
let arg_val = if expected_ty
3117+
.is_some_and(|ty| matches!(&*ty, crate::Type::Reference(_)))
3118+
{
3119+
self.translate_lvalue(*arg_id, decl).into()
3120+
} else {
3121+
self.translate_expr(*arg_id, decl)
3122+
};
3123+
if expected_ty.is_some_and(is_slice) {
30863124
let wrapped = self.wrap_as_slice(
30873125
arg_val.into_pointer_value(),
30883126
self.representation_type(*arg_id, decl),

tests/cases/references/ref_param_stack.lyte

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// skip-backend: asm
2-
// skip-backend: llvm
32
// expected stdout:
43
// compilation successful
54
// 42

0 commit comments

Comments
 (0)