Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/CthulhuBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using WidthLimitedIO

using Core: MethodInstance, MethodMatch
using Core.IR
using .CC: AbstractInterpreter, ApplyCallInfo, CallInfo as CCCallInfo, ConstCallInfo,
using .CC: AbstractInterpreter, CallMeta, ApplyCallInfo, CallInfo as CCCallInfo, ConstCallInfo,
EFFECTS_TOTAL, Effects, IncrementalCompact, InferenceParams, InferenceResult,
InferenceState, IRCode, LimitedAccuracy, MethodMatchInfo, MethodResultPure,
NativeInterpreter, NoCallInfo, OptimizationParams, OptimizationState,
Expand Down Expand Up @@ -469,7 +469,7 @@ function _descend(term::AbstractTerminal, interp::AbstractInterpreter, curs::Abs
if sourcenode !== nothing
show_sub_callsites = let callsite=callsite
map(info.callinfos) do ci
p = Base.unwrap_unionall(ci.def.specTypes).parameters
p = Base.unwrap_unionall(get_ci(ci).def.specTypes).parameters
if isa(sourcenode, TypedSyntax.MaybeTypedSyntaxNode) && length(p) == length(JuliaSyntax.children(sourcenode)) + 1
newnode = copy(sourcenode)
for (i, child) in enumerate(JuliaSyntax.children(newnode))
Expand Down
8 changes: 8 additions & 0 deletions src/callsite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ get_ci(gci::CuCallInfo) = get_ci(gci.ci)
get_rt(gci::CuCallInfo) = get_rt(gci.ci)
get_effects(gci::CuCallInfo) = get_effects(gci.ci)

struct CthulhuCallInfo <: CCCallInfo
meta::CallMeta
end
CC.add_edges_impl(edges::Vector{Any}, info::CthulhuCallInfo) = CC.add_edges!(edges, info.meta.info)
CC.nsplit_impl(info::CthulhuCallInfo) = CC.nsplit(info.meta.info)
CC.getsplit_impl(info::CthulhuCallInfo, idx::Int) = CC.getsplit(info.meta.info, idx)
CC.getresult_impl(info::CthulhuCallInfo, idx::Int) = CC.getresult(info.meta.info, idx)

struct Callsite
id::Int # ssa-id
info::CallInfo
Expand Down
2 changes: 1 addition & 1 deletion src/codeview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ function add_callsites!(d::AbstractDict, visited_cis::AbstractSet, diagnostics::
# e.g. if f(x) = x is called with different types we print nothing.
key = (mi.def.file, mi.def.line)
if haskey(d, key)
if !isnothing(d[key]) && mi != d[key].mi
if !isnothing(d[key]) && mi != d[key].ci.def
d[key] = nothing
push!(diagnostics,
TypedSyntax.Diagnostic(
Expand Down
98 changes: 80 additions & 18 deletions src/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ end
const InferenceKey = Union{CodeInstance,InferenceResult} # TODO make this `CodeInstance` fully
const InferenceDict{InferenceValue} = IdDict{InferenceKey, InferenceValue}
const PC2Remarks = Vector{Pair{Int, String}}
const PC2CallMeta = Dict{Int, CallMeta}
const PC2Effects = Dict{Int, Effects}
const PC2Excts = Dict{Int, Any}

Expand All @@ -29,6 +30,7 @@ struct CthulhuInterpreter <: AbstractInterpreter
native::AbstractInterpreter
unopt::InferenceDict{InferredSource}
remarks::InferenceDict{PC2Remarks}
calls::InferenceDict{PC2CallMeta}
effects::InferenceDict{PC2Effects}
exception_types::InferenceDict{PC2Excts}
end
Expand All @@ -39,6 +41,7 @@ function CthulhuInterpreter(interp::AbstractInterpreter=NativeInterpreter())
interp,
InferenceDict{InferredSource}(),
InferenceDict{PC2Remarks}(),
InferenceDict{PC2CallMeta}(),
InferenceDict{PC2Effects}(),
InferenceDict{PC2Excts}())
end
Expand Down Expand Up @@ -96,6 +99,21 @@ function CC.update_exc_bestguess!(interp::CthulhuInterpreter, @nospecialize(exct
frame::InferenceState)
end

function CC.abstract_call(interp::CthulhuInterpreter, arginfo::CC.ArgInfo, sstate::CC.StatementState, sv::InferenceState)
call = @invoke CC.abstract_call(interp::AbstractInterpreter, arginfo::CC.ArgInfo, sstate::CC.StatementState, sv::InferenceState)
if isa(sv, InferenceState)
key = get_inference_key(sv)
if key !== nothing
CC.Future{Any}(call, interp, sv) do call, interp, sv
calls = get!(PC2CallMeta, interp.calls, key)
calls[sv.currpc] = call
nothing
end
end
end
return call
end

function InferredSource(state::InferenceState)
unoptsrc = copy(state.src)
exct = state.result.exc_result
Expand All @@ -107,6 +125,66 @@ function InferredSource(state::InferenceState)
exct)
end

@static if VERSION ≥ v"1.13-"
function _finishinfer!(frame::InferenceState, interp::CthulhuInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance})
return @invoke CC.finishinfer!(frame::InferenceState, interp::AbstractInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance})
end
else
function _finishinfer!(frame::InferenceState, interp::CthulhuInterpreter, cycleid::Int)
return @invoke CC.finishinfer!(frame::InferenceState, interp::AbstractInterpreter, cycleid::Int)
end
end

function cthulhu_finish(result::Union{Nothing, InferenceResult}, frame::InferenceState, interp::CthulhuInterpreter)
key = get_inference_key(frame)
key === nothing && return result
interp.unopt[key] = InferredSource(frame)

# Wrap `CallInfo`s with `CthulhuCallInfo`s post-inference.
calls = get(interp.calls, key, nothing)
isnothing(calls) && return result
for (i, info) in enumerate(frame.stmt_info)
info === NoCallInfo() && continue
call = get(calls, i, nothing)
call === nothing && continue
if isa(info, CC.UnionSplitApplyCallInfo)
# XXX: `UnionSplitApplyCallInfo` is specially handled in `CC.inline_apply!`,
# so we can't shove it under a `CthulhuCallInfo`.
frame.stmt_info[i] = pack_cthulhuinfo_in_unionsplit(call, info)
else
frame.stmt_info[i] = CthulhuCallInfo(call)
end
end

return result
end

# Rebuild a `CC.UnionSplitApplyCallInfo` structure where inner `ApplyCallInfo`s wrap a `CthulhuCallInfo`.
# Note that technically, `rt`/`exct`/`effects`/`refinements` are incorrect for each apply call as they
# apply to the union split as a whole, not to individual branches. The idea is simply to preserve them.
function pack_cthulhuinfo_in_unionsplit(call::CallMeta, info::CC.UnionSplitApplyCallInfo)
infos = CC.ApplyCallInfo[]
for apply in info.infos
meta = CallMeta(call.rt, call.exct, call.effects, apply.call, call.refinements)
push!(infos, CC.ApplyCallInfo(CthulhuCallInfo(meta), apply.arginfo))
end
return CC.UnionSplitApplyCallInfo(infos)
end

# Build a `CthulhuCallInfo` structure wrapping `CC.UnionSplitApplyCallInfo`.
function unpack_cthulhuinfo_from_unionsplit(info::CC.UnionSplitApplyCallInfo)
isempty(info.infos) && return nothing
apply = info.infos[1]
isa(apply.call, CthulhuCallInfo) || return nothing
(; rt, exct, effects, refinements) = apply.call.meta
infos = CC.ApplyCallInfo[]
for apply in info.infos
push!(infos, CC.ApplyCallInfo(apply.call.meta.info, apply.arginfo))
end
call = CallMeta(rt, exct, effects, CC.UnionSplitApplyCallInfo(infos), refinements)
return CthulhuCallInfo(call)
end

function create_cthulhu_source(result::InferenceResult, effects::Effects)
isa(result.src, OptimizationState) || return result.src
opt = result.src
Expand All @@ -127,25 +205,9 @@ function set_cthulhu_source!(result::InferenceResult)
end

@static if VERSION ≥ v"1.13-"
CC.finishinfer!(state::InferenceState, interp::CthulhuInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance}) = cthulhu_finish(CC.finishinfer!, state, interp, cycleid, opt_cache)
function cthulhu_finish(@specialize(finishfunc), state::InferenceState, interp::CthulhuInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance})
res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance})
key = get_inference_key(state)
if key !== nothing
interp.unopt[key] = InferredSource(state)
end
return res
end
CC.finishinfer!(state::InferenceState, interp::CthulhuInterpreter, cycleid::Int, opt_cache::IdDict{MethodInstance, CodeInstance}) = cthulhu_finish(_finishinfer!(state, interp, cycleid, opt_cache), state, interp)
else
function cthulhu_finish(@specialize(finishfunc), state::InferenceState, interp::CthulhuInterpreter, cycleid::Int)
res = @invoke finishfunc(state::InferenceState, interp::AbstractInterpreter, cycleid::Int)
key = get_inference_key(state)
if key !== nothing
interp.unopt[key] = InferredSource(state)
end
return res
end
CC.finishinfer!(state::InferenceState, interp::CthulhuInterpreter, cycleid::Int) = cthulhu_finish(CC.finishinfer!, state, interp, cycleid)
CC.finishinfer!(state::InferenceState, interp::CthulhuInterpreter, cycleid::Int) = cthulhu_finish(_finishinfer!(state, interp, cycleid), state, interp)
end

function CC.finish!(interp::CthulhuInterpreter, caller::InferenceState, validation_world::UInt, time_before::UInt64)
Expand Down
28 changes: 19 additions & 9 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,18 @@ function find_callsites(interp::AbstractInterpreter, CI::Union{CodeInfo,IRCode},
if stmt_infos !== nothing && is_call_expr(stmt, optimize)
info = stmt_infos[id]
if info !== nothing
rt = ignorelimited(argextype(SSAValue(id), CI, sptypes, slottypes))
if isa(info, CC.UnionSplitApplyCallInfo)
info = something(unpack_cthulhuinfo_from_unionsplit(info), info)
end
if isa(info, CthulhuCallInfo)
# We have a `CallMeta` available.
(; info, rt, exct, effects) = info.meta
@assert !isa(info, CthulhuCallInfo)
else
rt = ignorelimited(argextype(SSAValue(id), CI, sptypes, slottypes))
exct = isnothing(pc2excts) ? nothing : get(pc2excts, id, nothing)
effects = nothing
end
# in unoptimized IR, there may be `slot = rhs` expressions, which `argextype` doesn't handle
# so extract rhs for such an case
local args = stmt.args
Expand All @@ -50,8 +61,7 @@ function find_callsites(interp::AbstractInterpreter, CI::Union{CodeInfo,IRCode},
t = argextype(args[i], CI, sptypes, slottypes)
argtypes[i] = ignorelimited(t)
end
exct = isnothing(pc2excts) ? nothing : get(pc2excts, id, nothing)
callinfos = process_info(interp, info, argtypes, rt, optimize, exct)
callinfos = process_info(interp, info, argtypes, rt, optimize, exct, effects)
isempty(callinfos) && continue
callsite = let
if length(callinfos) == 1
Expand Down Expand Up @@ -146,8 +156,8 @@ end

function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInfo),
argtypes::ArgTypes, @nospecialize(rt), optimize::Bool,
@nospecialize(exct))
process_recursive(@nospecialize(newinfo)) = process_info(interp, newinfo, argtypes, rt, optimize, exct)
@nospecialize(exct), effects::Union{Effects, Nothing})
process_recursive(@nospecialize(newinfo)) = process_info(interp, newinfo, argtypes, rt, optimize, exct, effects)

if isa(info, MethodResultPure)
if isa(info.info, CC.ReturnTypeCallInfo)
Expand All @@ -162,7 +172,7 @@ function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInf
if edge === nothing
RTCallInfo(unwrapconst(argtypes[1]), argtypes[2:end], rt, exct)
else
effects = get_effects(edge)
effects = @something(effects, get_effects(edge))
EdgeCallInfo(edge, rt, effects, exct)
end
end for edge in info.edges if edge !== nothing]
Expand All @@ -183,7 +193,7 @@ function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInf
elseif isa(info, CC.InvokeCallInfo)
edge = info.edge
if edge !== nothing
effects = get_effects(edge)
effects = @something(effects, get_effects(edge))
thisinfo = EdgeCallInfo(edge, rt, effects)
innerinfo = process_const_info(interp, thisinfo, argtypes, rt, info.result, optimize, exct)
else
Expand All @@ -194,7 +204,7 @@ function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInf
elseif isa(info, CC.OpaqueClosureCallInfo)
edge = info.edge
if edge !== nothing
effects = get_effects(edge)
effects = @something(effects, get_effects(edge))
thisinfo = EdgeCallInfo(edge, rt, effects)
innerinfo = process_const_info(interp, thisinfo, argtypes, rt, info.result, optimize, exct)
else
Expand All @@ -212,7 +222,7 @@ function process_info(interp::AbstractInterpreter, @nospecialize(info::CCCallInf
return CallInfo[]
elseif isa(info, CC.ReturnTypeCallInfo)
newargtypes = argtypes[2:end]
callinfos = process_info(interp, info.info, newargtypes, unwrapType(widenconst(rt)), optimize, exct)
callinfos = process_info(interp, info.info, newargtypes, unwrapType(widenconst(rt)), optimize, exct, effects)
if length(callinfos) == 1
vmi = only(callinfos)
else
Expand Down
Loading