Skip to content

Commit 5b0d49e

Browse files
Kenotopolarityaviatesk
authored
Co-authored-by: Cody Tapscott <topolarity@tapscott.me> Co-authored-by: Shuhei Kadowaki <aviatesk@gmail.com>
1 parent b98d972 commit 5b0d49e

8 files changed

Lines changed: 125 additions & 66 deletions

File tree

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Cthulhu"
22
uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
33
authors = ["Valentin Churavy <v.churavy@gmail.com> and contributors"]
4-
version = "2.11.1"
4+
version = "2.12.0"
55

66
[deps]
77
CodeTracking = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
@@ -24,7 +24,7 @@ JuliaSyntax = "0.4"
2424
PrecompileTools = "1"
2525
Preferences = "1"
2626
REPL = "1.9"
27-
TypedSyntax = "1.2.2"
27+
TypedSyntax = "1.3.0"
2828
UUIDs = "1.9"
2929
Unicode = "1.9"
3030
WidthLimitedIO = "1"

TypedSyntax/src/node.jl

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,23 @@ const no_default_value = NoDefaultValue()
2424
# These are TypedSyntaxNode constructor helpers
2525
# Call these directly if you want both the TypedSyntaxNode and the `mappings` list,
2626
# where `mappings[i]` corresponds to the list of nodes matching `(src::CodeInfo).code[i]`.
27-
function tsn_and_mappings(@nospecialize(f), @nospecialize(t); kwargs...)
28-
m = which(f, t)
29-
src, rt = getsrc(f, t)
30-
tsn_and_mappings(m, src, rt; kwargs...)
27+
function tsn_and_mappings(@nospecialize(f), @nospecialize(tt=Base.default_tt(f)); kwargs...)
28+
inferred_result = get_inferred_result(f, tt)
29+
return tsn_and_mappings(inferred_result.mi, inferred_result.src, inferred_result.rt; kwargs...)
3130
end
3231

33-
function tsn_and_mappings(m::Method, src::CodeInfo, @nospecialize(rt); warn::Bool=true, strip_macros::Bool=false, kwargs...)
32+
function tsn_and_mappings(mi::MethodInstance, src::CodeInfo, @nospecialize(rt); warn::Bool=true, strip_macros::Bool=false, kwargs...)
33+
m = mi.def::Method
3434
def = definition(String, m)
3535
if isnothing(def)
3636
warn && @warn "couldn't retrieve source of $m"
3737
return nothing, nothing
3838
end
39-
return tsn_and_mappings(m, src, rt, def...; warn, strip_macros, kwargs...)
39+
return tsn_and_mappings(mi, src, rt, def...; warn, strip_macros, kwargs...)
4040
end
4141

42-
function tsn_and_mappings(m::Method, src::CodeInfo, @nospecialize(rt), sourcetext::AbstractString, lineno::Integer; warn::Bool=true, strip_macros::Bool=false, kwargs...)
42+
function tsn_and_mappings(mi::MethodInstance, src::CodeInfo, @nospecialize(rt), sourcetext::AbstractString, lineno::Integer; warn::Bool=true, strip_macros::Bool=false, kwargs...)
43+
m = mi.def::Method
4344
filename = isnothing(functionloc(m)[1]) ? string(m.file) : functionloc(m)[1]
4445
rootnode = JuliaSyntax.parsestmt(SyntaxNode, sourcetext; filename=filename, first_line=lineno, kwargs...)
4546
if strip_macros
@@ -50,22 +51,26 @@ function tsn_and_mappings(m::Method, src::CodeInfo, @nospecialize(rt), sourcetex
5051
end
5152
end
5253
Δline = lineno - m.line # offset from original line number (Revise)
53-
mappings, symtyps = map_ssas_to_source(src, rootnode, Δline)
54+
mappings, symtyps = map_ssas_to_source(src, mi, rootnode, Δline)
5455
node = TypedSyntaxNode(rootnode, src, mappings, symtyps)
5556
node.typ = rt
5657
return node, mappings
5758
end
5859

59-
TypedSyntaxNode(@nospecialize(f), @nospecialize(t); kwargs...) = tsn_and_mappings(f, t; kwargs...)[1]
60+
TypedSyntaxNode(@nospecialize(f), @nospecialize(tt=Base.default_tt(f)); kwargs...) = tsn_and_mappings(f, tt; kwargs...)[1]
6061

6162
function TypedSyntaxNode(mi::MethodInstance; kwargs...)
62-
m = mi.def::Method
63-
src, rt = getsrc(mi)
64-
tsn_and_mappings(m, src, rt; kwargs...)[1]
63+
src, rt = code_typed1_tsn(mi)
64+
tsn_and_mappings(mi, src, rt; kwargs...)[1]
65+
end
66+
67+
function TypedSyntaxNode(rootnode::SyntaxNode, @nospecialize(f), @nospecialize(tt=Base.default_tt(f)); kwargs...)
68+
inferred_result = get_inferred_result(f, tt)
69+
TypedSyntaxNode(rootnode, inferred_result.src, inferred_result.mi; kwargs...)
6570
end
6671

67-
TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, Δline::Integer=0) =
68-
TypedSyntaxNode(rootnode, src, map_ssas_to_source(src, rootnode, Δline)...)
72+
TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, mi::MethodInstance, Δline::Integer=0) =
73+
TypedSyntaxNode(rootnode, src, map_ssas_to_source(src, mi, rootnode, Δline)...)
6974

7075
function TypedSyntaxNode(rootnode::SyntaxNode, src::CodeInfo, mappings, symtyps)
7176
# There may be ambiguous assignments back to the source; preserve just the unambiguous ones
@@ -304,17 +309,57 @@ function sparam_name(mi::MethodInstance, i::Int)
304309
return sig.var.name
305310
end
306311

307-
function getsrc(@nospecialize(f), @nospecialize(t))
308-
srcrts = code_typed(f, t; debuginfo=:source, optimize=false)
309-
return only(srcrts)
310-
end
311-
312-
function getsrc(mi::MethodInstance)
313-
cis = Base.code_typed_by_type(mi.specTypes; debuginfo=:source, optimize=false)
314-
isempty(cis) && error("no applicable type-inferred code found for ", mi)
315-
length(cis) == 1 || error("got $(length(cis)) possible type-inferred results for ", mi,
316-
", you may need a more specialized signature")
317-
return cis[1]::Pair{CodeInfo}
312+
@static if isdefined(Base, :method_instances)
313+
using Base: method_instances
314+
else
315+
function method_instances(@nospecialize(f), @nospecialize(t), world::UInt)
316+
tt = Base.signature_type(f, t)
317+
results = Core.MethodInstance[]
318+
# this make a better error message than the typeassert that follows
319+
world == typemax(UInt) && error("code reflection cannot be used from generated functions")
320+
for match in Base._methods_by_ftype(tt, -1, world)::Vector
321+
instance = Core.Compiler.specialize_method(match)
322+
push!(results, instance)
323+
end
324+
return results
325+
end
326+
end
327+
328+
struct InferredResult
329+
mi::MethodInstance
330+
src::CodeInfo
331+
rt
332+
InferredResult(mi::MethodInstance, src::CodeInfo, @nospecialize(rt)) = new(mi, src, rt)
333+
end
334+
function get_inferred_result(@nospecialize(f), @nospecialize(tt=Base.default_tt(f)),
335+
world::UInt=Base.get_world_counter())
336+
mis = method_instances(f, tt, world)
337+
if isempty(mis)
338+
sig = sprint(Base.show_tuple_as_call, Symbol(""), Base.signature_type(f, tt))
339+
error("no applicable type-inferred code found for ", sig)
340+
elseif length(mis) 1
341+
sig = sprint(Base.show_tuple_as_call, Symbol(""), Base.signature_type(f, tt))
342+
error("got $(length(mis)) possible type-inferred results for ", sig,
343+
", you may need a more specialized signature")
344+
end
345+
mi = only(mis)
346+
return InferredResult(mi, code_typed1_tsn(mi)...)
347+
end
348+
349+
code_typed1_tsn(mi::MethodInstance) = code_typed1_by_method_instance(mi; optimize=false, debuginfo=:source)
350+
351+
function code_typed1_by_method_instance(mi::MethodInstance;
352+
optimize::Bool=true,
353+
debuginfo::Symbol=:default,
354+
world::UInt=Base.get_world_counter(),
355+
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world))
356+
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
357+
error("code reflection should not be used from generated functions")
358+
debuginfo = Base.IRShow.debuginfo(debuginfo)
359+
code, rt = Core.Compiler.typeinf_code(interp, mi.def::Method, mi.specTypes, mi.sparam_vals, optimize)
360+
code isa CodeInfo || error("no code is available for ", mi)
361+
debuginfo === :none && Base.remove_linenums!(code)
362+
return Pair{CodeInfo,Any}(code, rt)
318363
end
319364

320365
function is_function_def(node) # this is not `Base.is_function_def`
@@ -397,8 +442,7 @@ end
397442
# Main logic for mapping `src.code[i]` to node(s) in the SyntaxNode tree
398443
# Success: when we map it to a unique node
399444
# Δline is the (Revise) offset of the line number
400-
function map_ssas_to_source(src::CodeInfo, rootnode::SyntaxNode, Δline::Int)
401-
mi = src.parent::MethodInstance
445+
function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxNode, Δline::Int)
402446
slottypes = src.slottypes::Union{Nothing, Vector{Any}}
403447
have_slottypes = slottypes !== nothing
404448
ssavaluetypes = src.ssavaluetypes::Vector{Any}
@@ -428,7 +472,7 @@ function map_ssas_to_source(src::CodeInfo, rootnode::SyntaxNode, Δline::Int)
428472
# (Essentially `copy!(mapped, filter(predicate, targets))`)
429473
function append_targets_for_line!(mapped#=::Vector{nodes}=#, i::Int, targets#=::Vector{nodes}=#)
430474
j = src.codelocs[i]
431-
lt = src.linetable::Vector{Any}
475+
lt = src.linetable::Vector
432476
start = getline(lt, j) + Δline
433477
stop = getnextline(lt, j, Δline) - 1
434478
linerange = start : stop
@@ -736,7 +780,7 @@ function map_ssas_to_source(src::CodeInfo, rootnode::SyntaxNode, Δline::Int)
736780
end
737781
return mappings, symtyps
738782
end
739-
map_ssas_to_source(src::CodeInfo, rootnode::SyntaxNode, Δline::Integer) = map_ssas_to_source(src, rootnode, Int(Δline))
783+
map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxNode, Δline::Integer) = map_ssas_to_source(src, mi, rootnode, Int(Δline))
740784

741785
function follow_back(src, arg)
742786
# Follow SSAValue backward to see if it maps back to a slot

TypedSyntax/test/exhaustive.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ const goodmis = Core.MethodInstance[]
2626
continue
2727
end
2828
try
29-
tsn, _ = TypedSyntax.tsn_and_mappings(m, src, rt, ret...; warn=false)
29+
tsn, _ = TypedSyntax.tsn_and_mappings(mi, src, rt, ret...; warn=false)
3030
@test isa(tsn, TypedSyntaxNode)
3131
push!(goodmis, mi)
3232
catch

TypedSyntax/test/runtests.jl

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using JuliaSyntax: JuliaSyntax, SyntaxNode, children, child, sourcetext, kind, @K_str
2-
using TypedSyntax: TypedSyntax, TypedSyntaxNode, getsrc
2+
using TypedSyntax: TypedSyntax, TypedSyntaxNode
33
using Dates, InteractiveUtils, Test
44

55
has_name_typ(node, name::Symbol, @nospecialize(T)) = kind(node) == K"Identifier" && node.val === name && node.typ === T
@@ -15,8 +15,7 @@ include("test_module.jl")
1515
"""
1616
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN1.jl")
1717
TSN.eval(Expr(rootnode))
18-
src, _ = getsrc(TSN.f, (Float32, Int, Float64))
19-
tsn = TypedSyntaxNode(rootnode, src)
18+
tsn = TypedSyntaxNode(rootnode, TSN.f, (Float32, Int, Float64))
2019
sig, body = children(tsn)
2120
@test children(sig)[2].typ === Float32
2221
@test children(sig)[3].typ === Int
@@ -33,8 +32,7 @@ include("test_module.jl")
3332
"""
3433
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
3534
TSN.eval(Expr(rootnode))
36-
src, _ = getsrc(TSN.g, (Int16, Int16, Int32))
37-
tsn = TypedSyntaxNode(rootnode, src)
35+
tsn = TypedSyntaxNode(rootnode, TSN.g, (Int16, Int16, Int32))
3836
sig, body = children(tsn)
3937
@test length(children(sig)) == 4
4038
@test children(body)[2].typ === Int32
@@ -46,8 +44,7 @@ include("test_module.jl")
4644
st = "math(x) = x + sin(x + π / 4)"
4745
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
4846
TSN.eval(Expr(rootnode))
49-
src, _ = getsrc(TSN.math, (Int,))
50-
tsn = TypedSyntaxNode(rootnode, src)
47+
tsn = TypedSyntaxNode(rootnode, TSN.math, (Int,))
5148
sig, body = children(tsn)
5249
@test has_name_typ(child(body, 1), :x, Int)
5350
@test has_name_typ(child(body, 3, 2, 1), :x, Int)
@@ -70,8 +67,7 @@ include("test_module.jl")
7067
st = "math2(x) = sin(x) + sin(x)"
7168
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN2.jl")
7269
TSN.eval(Expr(rootnode))
73-
src, _ = getsrc(TSN.math2, (Int,))
74-
tsn = TypedSyntaxNode(rootnode, src)
70+
tsn = TypedSyntaxNode(rootnode, TSN.math2, (Int,))
7571
sig, body = children(tsn)
7672
@test body.typ === Float64
7773
@test_broken child(body, 1).typ === Float64
@@ -91,8 +87,7 @@ include("test_module.jl")
9187
)
9288
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN3.jl")
9389
TSN.eval(Expr(rootnode))
94-
src, _ = getsrc(TSN.firstfirst, (Vector{Vector{Real}},))
95-
tsn = TypedSyntaxNode(rootnode, src)
90+
tsn = TypedSyntaxNode(rootnode, TSN.firstfirst, (Vector{Vector{Real}},))
9691
sig, body = children(tsn)
9792
@test child(body, idxsinner...).typ === nothing
9893
@test child(body, idxsouter...).typ === Vector{Real}
@@ -150,8 +145,7 @@ include("test_module.jl")
150145
"""
151146
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN4.jl")
152147
TSN.eval(Expr(rootnode))
153-
src, rt = getsrc(TSN.setlist!, (Vector{Vector{Float32}}, Vector{Vector{UInt8}}, Int, Int))
154-
tsn = TypedSyntaxNode(rootnode, src)
148+
tsn = TypedSyntaxNode(rootnode, TSN.setlist!, (Vector{Vector{Float32}}, Vector{Vector{UInt8}}, Int, Int))
155149
sig, body = children(tsn)
156150
nodelist = child(body, 1, 2, 1, 1) # `listget`
157151
@test sourcetext(nodelist) == "listget" && nodelist.typ === Vector{Vector{UInt8}}
@@ -175,8 +169,7 @@ include("test_module.jl")
175169
"""
176170
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN5.jl")
177171
TSN.eval(Expr(rootnode))
178-
src, rt = getsrc(TSN.callfindmin, (Vector{Float64},))
179-
tsn = TypedSyntaxNode(rootnode, src)
172+
tsn = TypedSyntaxNode(rootnode, TSN.callfindmin, (Vector{Float64},))
180173
sig, body = children(tsn)
181174
t = child(body, 1, 1)
182175
@test kind(t) == K"tuple"
@@ -280,18 +273,18 @@ include("test_module.jl")
280273
"""
281274
rootnode = JuliaSyntax.parsestmt(SyntaxNode, st; filename="TSN6.jl")
282275
TSN.eval(Expr(rootnode))
283-
src, rt = getsrc(TSN.avoidzero, (Int,))
276+
inferred_result = TypedSyntax.get_inferred_result(TSN.avoidzero, (Int,))
277+
src, rt, mi = inferred_result.src, inferred_result.rt, inferred_result.mi
284278
# src looks like this:
285279
# %1 = Main.TSN.:(var"#avoidzero#6")(true, #self#, x)::Float64
286280
# return %1
287281
# Consequently there is nothing to match, but at least we shouldn't error
288-
tsn = TypedSyntaxNode(rootnode, src)
282+
tsn = TypedSyntaxNode(rootnode, src, mi)
289283
@test isa(tsn, TypedSyntaxNode)
290284
@test rt === Float64
291285
# Try the kwbodyfunc
292286
m = which(TSN.avoidzero, (Int,))
293-
src, rt = getsrc(Base.bodyfunction(m), (Bool, typeof(TSN.avoidzero), Int,))
294-
tsn = TypedSyntaxNode(rootnode, src)
287+
tsn = TypedSyntaxNode(rootnode, Base.bodyfunction(m), (Bool, typeof(TSN.avoidzero), Int,))
295288
sig, body = children(tsn)
296289
isz = child(body, 2, 1, 1)
297290
@test kind(isz) == K"call" && child(isz, 1).val == :iszero
@@ -520,8 +513,7 @@ include("test_module.jl")
520513
@test_broken body.typ == Int
521514

522515
# Construction from MethodInstance
523-
src, rt = TypedSyntax.getsrc(TSN.myoftype, (Float64, Int))
524-
tsn = TypedSyntaxNode(src.parent)
516+
tsn = TypedSyntaxNode(TSN.myoftype, (Float64, Int))
525517
sig, body = children(tsn)
526518
node = child(body, 1)
527519
@test node.typ === Type{Float64}
@@ -641,10 +633,10 @@ include("test_module.jl")
641633
@test isa(tsnc, TypedSyntaxNode)
642634

643635
# issue 487
644-
m = which(TSN.f487, (Int,))
645-
src, rt = getsrc(TSN.f487, (Int,))
636+
inferred_result = TypedSyntax.get_inferred_result(TSN.f487, (Int,))
637+
src, mi = inferred_result.src, inferred_result.mi
646638
rt = Core.Const(1)
647-
tsn, _ = TypedSyntax.tsn_and_mappings(m, src, rt)
639+
tsn, _ = TypedSyntax.tsn_and_mappings(mi, src, rt)
648640
@test_nowarn str = sprint(tsn; context=:color=>false) do io, obj
649641
printstyled(io, obj; hide_type_stable=false)
650642
end

src/codeview.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,6 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
182182
# we're working on pre-optimization state, need to ignore `LimitedAccuracy`
183183
src = copy(src)
184184
src.ssavaluetypes = mapany(ignorelimited, src.ssavaluetypes::Vector{Any})
185-
src.rettype = ignorelimited(src.rettype)
186185

187186
if src.slotnames !== nothing
188187
slotnames = Base.sourceinfo_slotnames(src)

src/interpreter.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,37 @@ function create_cthulhu_source(@nospecialize(opt), effects::Effects)
131131
return OptimizedSource(ir, opt.src, opt.src.inlineable, effects)
132132
end
133133

134+
@static if VERSION v"1.12.0-DEV.15"
135+
function CC.transform_result_for_cache(interp::CthulhuInterpreter,
136+
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, can_discard_trees::Bool=false)
137+
return create_cthulhu_source(result.src, result.ipo_effects)
138+
end
139+
else
134140
function CC.transform_result_for_cache(interp::CthulhuInterpreter,
135141
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
136142
return create_cthulhu_source(result.src, result.ipo_effects)
137143
end
144+
end
138145

139-
@static if VERSION v"1.11.0-DEV.879"
146+
@static if VERSION v"1.12.0-DEV.45"
147+
function CC.src_inlining_policy(interp::CthulhuInterpreter,
148+
@nospecialize(src), @nospecialize(info::CCCallInfo), stmt_flag::UInt32)
149+
if isa(src, OptimizedSource)
150+
if CC.is_stmt_inline(stmt_flag) || src.isinlineable
151+
return true
152+
end
153+
return false
154+
else
155+
@assert src isa CC.IRCode || src === nothing "invalid Cthulhu code cache"
156+
# the default inlining policy may try additional effor to find the source in a local cache
157+
return @invoke CC.src_inlining_policy(interp::AbstractInterpreter,
158+
src::Any, info::CCCallInfo, stmt_flag::UInt32)
159+
end
160+
end
161+
CC.retrieve_ir_for_inlining(cached_result::CodeInstance, src::OptimizedSource) = CC.copy(src.ir)
162+
CC.retrieve_ir_for_inlining(mi::Core.MethodInstance, src::OptimizedSource, preserve_local_sources::Bool) =
163+
CC.retrieve_ir_for_inlining(mi, src.ir, preserve_local_sources)
164+
elseif VERSION v"1.11.0-DEV.879"
140165
function CC.inlining_policy(interp::CthulhuInterpreter,
141166
@nospecialize(src), @nospecialize(info::CCCallInfo), stmt_flag::UInt32)
142167
if isa(src, OptimizedSource)
@@ -181,7 +206,7 @@ function CC.IRInterpretationState(interp::CthulhuInterpreter,
181206
src = inferred.src
182207
method_info = CC.MethodInfo(src)
183208
return CC.IRInterpretationState(interp, method_info, ir, mi, argtypes, world,
184-
src.min_world, src.max_world)
209+
code.min_world, code.max_world)
185210
end
186211

187212
@static if VERSION v"1.11.0-DEV.737"

src/reflection.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,12 @@ function add_sourceline!(locs, CI, stmtidx::Int)
350350
end
351351

352352
function get_typed_sourcetext(mi::MethodInstance, src::CodeInfo, @nospecialize(rt); warn::Bool=true)
353-
meth = mi.def::Method
354-
tsn, mappings = TypedSyntax.tsn_and_mappings(meth, src, rt; warn, strip_macros=true)
355-
return truncate_if_defaultargs!(tsn, mappings, meth)
353+
tsn, mappings = TypedSyntax.tsn_and_mappings(mi, src, rt; warn, strip_macros=true)
354+
return truncate_if_defaultargs!(tsn, mappings, mi.def::Method)
356355
end
357356

358357
function get_typed_sourcetext(mi::MethodInstance, ::IRCode, @nospecialize(rt); kwargs...)
359-
src, rt = TypedSyntax.getsrc(mi)
358+
src, rt = TypedSyntax.code_typed1_tsn(mi)
360359
return get_typed_sourcetext(mi, src, rt; kwargs...)
361360
end
362361

0 commit comments

Comments
 (0)