Skip to content

Commit 31ce47d

Browse files
committed
[REPLCompletions] improve implementation of completions
- Restrict method completion to ignore strictly less specific ones - Fix various lookup bugs - Improve slurping of final expression Inspired by #43572 Co-authored-by: Lionel Zoubritzky <[email protected]>
1 parent 5907ac3 commit 31ce47d

File tree

2 files changed

+235
-100
lines changed

2 files changed

+235
-100
lines changed

stdlib/REPL/src/REPLCompletions.jl

+79-90
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ using Base: propertynames, something
99

1010
abstract type Completion end
1111

12+
struct TextCompletion <: Completion
13+
text::String
14+
end
15+
1216
struct KeywordCompletion <: Completion
1317
keyword::String
1418
end
@@ -37,10 +41,7 @@ struct FieldCompletion <: Completion
3741
end
3842

3943
struct MethodCompletion <: Completion
40-
func
41-
input_types::Type
4244
method::Method
43-
orig_method::Union{Nothing,Method} # if `method` is a keyword method, keep the original method for sensible printing
4445
end
4546

4647
struct BslashCompletion <: Completion
@@ -58,7 +59,9 @@ end
5859

5960
# interface definition
6061
function Base.getproperty(c::Completion, name::Symbol)
61-
if name === :keyword
62+
if name === :text
63+
return getfield(c, :text)::String
64+
elseif name === :keyword
6265
return getfield(c, :keyword)::String
6366
elseif name === :path
6467
return getfield(c, :path)::String
@@ -84,13 +87,14 @@ function Base.getproperty(c::Completion, name::Symbol)
8487
return getfield(c, name)
8588
end
8689

90+
_completion_text(c::TextCompletion) = c.text
8791
_completion_text(c::KeywordCompletion) = c.keyword
8892
_completion_text(c::PathCompletion) = c.path
8993
_completion_text(c::ModuleCompletion) = c.mod
9094
_completion_text(c::PackageCompletion) = c.package
9195
_completion_text(c::PropertyCompletion) = string(c.property)
9296
_completion_text(c::FieldCompletion) = string(c.field)
93-
_completion_text(c::MethodCompletion) = sprint(io -> show(io, isnothing(c.orig_method) ? c.method : c.orig_method::Method))
97+
_completion_text(c::MethodCompletion) = repr(c.method)
9498
_completion_text(c::BslashCompletion) = c.bslash
9599
_completion_text(c::ShellCompletion) = c.text
96100
_completion_text(c::DictCompletion) = c.key
@@ -125,7 +129,7 @@ function filtered_mod_names(ffunc::Function, mod::Module, name::AbstractString,
125129
end
126130

127131
# REPL Symbol Completions
128-
function complete_symbol(sym::String, ffunc, context_module::Module=Main)
132+
function complete_symbol(sym::String, @nospecialize(ffunc), context_module::Module=Main)
129133
mod = context_module
130134
name = sym
131135

@@ -407,62 +411,48 @@ end
407411
# will show it consist of Expr, QuoteNode's and Symbol's which all needs to
408412
# be handled differently to iterate down to get the value of whitespace_chars.
409413
function get_value(sym::Expr, fn)
414+
if sym.head === :quote || sym.head === :inert
415+
return sym.args[1], true
416+
end
410417
sym.head !== :. && return (nothing, false)
411418
for ex in sym.args
419+
ex, found = get_value(ex, fn)
420+
!found && return (nothing, false)
412421
fn, found = get_value(ex, fn)
413422
!found && return (nothing, false)
414423
end
415424
return (fn, true)
416425
end
417426
get_value(sym::Symbol, fn) = isdefined(fn, sym) ? (getfield(fn, sym), true) : (nothing, false)
418-
get_value(sym::QuoteNode, fn) = isdefined(fn, sym.value) ? (getfield(fn, sym.value), true) : (nothing, false)
427+
get_value(sym::QuoteNode, fn) = (sym.value, true)
419428
get_value(sym::GlobalRef, fn) = get_value(sym.name, sym.mod)
420429
get_value(sym, fn) = (sym, true)
421430

422431
# Return the type of a getfield call expression
423432
function get_type_getfield(ex::Expr, fn::Module)
424433
length(ex.args) == 3 || return Any, false # should never happen, but just for safety
425-
obj, x = ex.args[2:3]
434+
fld, found = get_value(ex.args[3], fn)
435+
fld isa Symbol || return Any, false
436+
obj = ex.args[2]
426437
objt, found = get_type(obj, fn)
427-
objt isa DataType || return Any, false
428438
found || return Any, false
429-
if x isa QuoteNode
430-
fld = x.value
431-
elseif isexpr(x, :quote) || isexpr(x, :inert)
432-
fld = x.args[1]
433-
else
434-
fld = nothing # we don't know how to get the value of variable `x` here
435-
end
436-
fld isa Symbol || return Any, false
439+
objt isa DataType || return Any, false
437440
hasfield(objt, fld) || return Any, false
438441
return fieldtype(objt, fld), true
439442
end
440443

441-
# Determines the return type with Base.return_types of a function call using the type information of the arguments.
442-
function get_type_call(expr::Expr)
444+
# Determines the return type with the Compiler of a function call using the type information of the arguments.
445+
function get_type_call(expr::Expr, fn::Module)
443446
f_name = expr.args[1]
444-
# The if statement should find the f function. How f is found depends on how f is referenced
445-
if isa(f_name, GlobalRef) && isconst(f_name.mod,f_name.name) && isdefined(f_name.mod,f_name.name)
446-
ft = typeof(eval(f_name))
447-
found = true
448-
else
449-
ft, found = get_type(f_name, Main)
450-
end
447+
f, found = get_type(f_name, fn)
451448
found || return (Any, false) # If the function f is not found return Any.
452449
args = Any[]
453-
for ex in expr.args[2:end] # Find the type of the function arguments
454-
typ, found = get_type(ex, Main)
450+
for i in 2:length(expr.args) # Find the type of the function arguments
451+
typ, found = get_type(expr.args[i], fn)
455452
found ? push!(args, typ) : push!(args, Any)
456453
end
457-
# use _methods_by_ftype as the function is supplied as a type
458454
world = Base.get_world_counter()
459-
matches = Base._methods_by_ftype(Tuple{ft, args...}, -1, world)::Vector
460-
length(matches) == 1 || return (Any, false)
461-
match = first(matches)::Core.MethodMatch
462-
# Typeinference
463-
interp = Core.Compiler.NativeInterpreter()
464-
return_type = Core.Compiler.typeinf_type(interp, match.method, match.spec_types, match.sparams)
465-
return_type === nothing && return (Any, false)
455+
return_type = Core.Compiler.return_type(Tuple{f, args...}, world)
466456
return (return_type, true)
467457
end
468458

@@ -477,15 +467,15 @@ function try_get_type(sym::Expr, fn::Module)
477467
if a1 === :getfield || a1 === GlobalRef(Core, :getfield)
478468
return get_type_getfield(sym, fn)
479469
end
480-
return get_type_call(sym)
470+
return get_type_call(sym, fn)
481471
elseif sym.head === :thunk
482472
thk = sym.args[1]
483473
rt = ccall(:jl_infer_thunk, Any, (Any, Any), thk::Core.CodeInfo, fn)
484474
rt !== Any && return (rt, true)
485475
elseif sym.head === :ref
486476
# some simple cases of `expand`
487477
return try_get_type(Expr(:call, GlobalRef(Base, :getindex), sym.args...), fn)
488-
elseif sym.head === :. && sym.args[2] isa QuoteNode # second check catches broadcasting
478+
elseif sym.head === :. && sym.args[2] isa QuoteNode # second check catches broadcasting
489479
return try_get_type(Expr(:call, GlobalRef(Core, :getfield), sym.args...), fn)
490480
end
491481
return (Any, false)
@@ -525,37 +515,52 @@ function get_type(T, found::Bool, default_any::Bool)
525515
end
526516

527517
# Method completion on function call expression that look like :(max(1))
518+
MAX_METHOD_COMPLETIONS = 40
528519
function complete_methods(ex_org::Expr, context_module::Module=Main)
529-
func, found = get_value(ex_org.args[1], context_module)::Tuple{Any,Bool}
530-
!found && return Completion[]
520+
out = Completion[]
521+
funct, found = get_type(ex_org.args[1], context_module)::Tuple{Any,Bool}
522+
!found && return out
531523

532524
args_ex, kwargs_ex = complete_methods_args(ex_org.args[2:end], ex_org, context_module, true, true)
525+
push!(args_ex, Vararg{Any})
526+
complete_methods!(out, funct, args_ex, kwargs_ex, MAX_METHOD_COMPLETIONS::Int)
533527

534-
out = Completion[]
535-
complete_methods!(out, func, args_ex, kwargs_ex)
536528
return out
537529
end
538530

531+
MAX_ANY_METHOD_COMPLETIONS = 10
539532
function complete_any_methods(ex_org::Expr, callee_module::Module, context_module::Module, moreargs::Bool, shift::Bool)
540533
out = Completion[]
541534
args_ex, kwargs_ex = try
535+
# this may throw, since we set default_any to false
542536
complete_methods_args(ex_org.args[2:end], ex_org, context_module, false, false)
543-
catch
537+
catch ex
538+
ex isa ArgumentError || rethrow()
544539
return out
545540
end
541+
moreargs && push!(args_ex, Vararg{Any})
546542

543+
seen = Base.IdSet()
547544
for name in names(callee_module; all=true)
548545
if !Base.isdeprecated(callee_module, name) && isdefined(callee_module, name)
549546
func = getfield(callee_module, name)
550547
if !isa(func, Module)
551-
complete_methods!(out, func, args_ex, kwargs_ex, moreargs)
552-
elseif callee_module === Main::Module && isa(func, Module)
548+
funct = Core.Typeof(func)
549+
if !in(funct, seen)
550+
push!(seen, funct)
551+
complete_methods!(out, funct, args_ex, kwargs_ex, MAX_ANY_METHOD_COMPLETIONS::Int)
552+
end
553+
elseif callee_module === Main && isa(func, Module)
553554
callee_module2 = func
554555
for name in names(callee_module2)
555-
if isdefined(callee_module2, name)
556+
if !Base.isdeprecated(callee_module2, name) && isdefined(callee_module2, name)
556557
func = getfield(callee_module, name)
557558
if !isa(func, Module)
558-
complete_methods!(out, func, args_ex, kwargs_ex, moreargs)
559+
funct = Core.Typeof(func)
560+
if !in(funct, seen)
561+
push!(seen, funct)
562+
complete_methods!(out, funct, args_ex, kwargs_ex, MAX_ANY_METHOD_COMPLETIONS::Int)
563+
end
559564
end
560565
end
561566
end
@@ -566,7 +571,8 @@ function complete_any_methods(ex_org::Expr, callee_module::Module, context_modul
566571
if !shift
567572
# Filter out methods where all arguments are `Any`
568573
filter!(out) do c
569-
isa(c, REPLCompletions.MethodCompletion) || return true
574+
isa(c, TextCompletion) && return false
575+
isa(c, MethodCompletion) || return true
570576
sig = Base.unwrap_unionall(c.method.sig)::DataType
571577
return !all(T -> T === Any || T === Vararg{Any}, sig.parameters[2:end])
572578
end
@@ -577,7 +583,7 @@ end
577583

578584
function complete_methods_args(funargs::Vector{Any}, ex_org::Expr, context_module::Module, default_any::Bool, allow_broadcasting::Bool)
579585
args_ex = Any[]
580-
kwargs_ex = Pair{Symbol,Any}[]
586+
kwargs_ex = false
581587
if allow_broadcasting && ex_org.head === :. && ex_org.args[2] isa Expr
582588
# handle broadcasting, but only handle number of arguments instead of
583589
# argument types
@@ -587,13 +593,11 @@ function complete_methods_args(funargs::Vector{Any}, ex_org::Expr, context_modul
587593
else
588594
for ex in funargs
589595
if isexpr(ex, :parameters)
590-
for x in ex.args
591-
n, v = isexpr(x, :kw) ? (x.args...,) : (x, x)
592-
push!(kwargs_ex, n => get_type(get_type(v, context_module)..., default_any))
596+
if !isempty(ex.args)
597+
kwargs_ex = true
593598
end
594599
elseif isexpr(ex, :kw)
595-
n, v = (ex.args...,)
596-
push!(kwargs_ex, n => get_type(get_type(v, context_module)..., default_any))
600+
kwargs_ex = true
597601
else
598602
push!(args_ex, get_type(get_type(ex, context_module)..., default_any))
599603
end
@@ -602,34 +606,18 @@ function complete_methods_args(funargs::Vector{Any}, ex_org::Expr, context_modul
602606
return args_ex, kwargs_ex
603607
end
604608

605-
function complete_methods!(out::Vector{Completion}, @nospecialize(func), args_ex::Vector{Any}, kwargs_ex::Vector{Pair{Symbol,Any}}, moreargs::Bool=true)
606-
ml = methods(func)
609+
function complete_methods!(out::Vector{Completion}, @nospecialize(funct), args_ex::Vector{Any}, kwargs_ex::Bool, max_method_completions::Int)
607610
# Input types and number of arguments
608-
if isempty(kwargs_ex)
609-
t_in = Tuple{Core.Typeof(func), args_ex...}
610-
na = length(t_in.parameters)::Int
611-
orig_ml = fill(nothing, length(ml))
612-
else
613-
isdefined(ml.mt, :kwsorter) || return out
614-
kwfunc = ml.mt.kwsorter
615-
kwargt = NamedTuple{(first.(kwargs_ex)...,), Tuple{last.(kwargs_ex)...}}
616-
t_in = Tuple{Core.Typeof(kwfunc), kwargt, Core.Typeof(func), args_ex...}
617-
na = length(t_in.parameters)::Int
618-
orig_ml = ml # this method is supposed to be used for printing
619-
ml = methods(kwfunc)
620-
func = kwfunc
621-
end
622-
if !moreargs
623-
na = typemax(Int)
611+
t_in = Tuple{funct, args_ex...}
612+
m = Base._methods_by_ftype(t_in, nothing, max_method_completions, Base.get_world_counter(),
613+
#=ambig=# true, Ref(typemin(UInt)), Ref(typemax(UInt)), Ptr{Int32}(C_NULL))
614+
if m === false
615+
push!(out, TextCompletion(sprint(Base.show_signature_function, funct) * "( too many methods to show )"))
624616
end
625-
626-
for (method::Method, orig_method) in zip(ml, orig_ml)
627-
ms = method.sig
628-
629-
# Check if the method's type signature intersects the input types
630-
if typeintersect(Base.rewrap_unionall(Tuple{(Base.unwrap_unionall(ms)::DataType).parameters[1 : min(na, end)]...}, ms), t_in) != Union{}
631-
push!(out, MethodCompletion(func, t_in, method, orig_method))
632-
end
617+
m isa Vector || return
618+
for match in m
619+
# TODO: if kwargs_ex, filter out methods without kwargs?
620+
push!(out, MethodCompletion(match.method))
633621
end
634622
end
635623

@@ -708,7 +696,7 @@ function bslash_completions(string::String, pos::Int)
708696
return (false, (Completion[], 0:-1, false))
709697
end
710698

711-
function dict_identifier_key(str::String, tag::Symbol, context_module::Module = Main)
699+
function dict_identifier_key(str::String, tag::Symbol, context_module::Module=Main)
712700
if tag === :string
713701
str_close = str*"\""
714702
elseif tag === :cmd
@@ -897,21 +885,22 @@ function completions(string::String, pos::Int, context_module::Module=Main, shif
897885
dotpos < startpos && (dotpos = startpos - 1)
898886
s = string[startpos:pos]
899887
comp_keywords && append!(suggestions, complete_keyword(s))
900-
# The case where dot and start pos is equal could look like: "(""*"").d","". or CompletionFoo.test_y_array[1].y
901-
# This case can be handled by finding the beginning of the expression. This is done below.
902-
if dotpos == startpos
888+
# if the start of the string is a `.`, try to consume more input to get back to the beginning of the last expression
889+
if 0 < startpos <= lastindex(string) && string[startpos] == '.'
903890
i = prevind(string, startpos)
904891
while 0 < i
905892
c = string[i]
906-
if c in [')', ']']
907-
if c==')'
908-
c_start='('; c_end=')'
909-
elseif c==']'
910-
c_start='['; c_end=']'
893+
if c in (')', ']')
894+
if c == ')'
895+
c_start = '('
896+
c_end = ')'
897+
elseif c == ']'
898+
c_start = '['
899+
c_end = ']'
911900
end
912901
frange, end_of_identifier = find_start_brace(string[1:prevind(string, i)], c_start=c_start, c_end=c_end)
902+
isempty(frange) && break # unbalanced parens
913903
startpos = first(frange)
914-
startpos == 0 && break
915904
i = prevind(string, startpos)
916905
elseif c in ('\'', '\"', '\`')
917906
s = "$c$c"*string[startpos:pos]

0 commit comments

Comments
 (0)