init with basic distance calc

This commit is contained in:
qwjyh 2025-04-13 23:28:51 +09:00
commit fdc6c1a3fa
9 changed files with 283 additions and 0 deletions

17
src/MethodSimilarity.jl Normal file
View file

@ -0,0 +1,17 @@
module MethodSimilarity
using Statistics
using LinearAlgebra
using ProgressMeter
include("Utils.jl")
using .Utils
include("method_info.jl")
include("distance.jl")
include("analyze.jl")
greet() = print("Hello World!")
end # module MethodSimilarity

7
src/Utils.jl Normal file
View file

@ -0,0 +1,7 @@
module Utils
export Choose
include("utils/choose.jl")
end

21
src/analyze.jl Normal file
View file

@ -0,0 +1,21 @@
function main()
all_methods = Method[]
for base_prop_sym in propertynames(Base)
base_prop = getproperty(Base, base_prop_sym)
if !(base_prop isa Function)
continue
end
if !Base.isexported(Base, base_prop_sym) && !Base.ispublic(Base, base_prop_sym)
continue
end
push!(all_methods, methods(base_prop)...)
end
num_methods = length(all_methods)
@info "number of methods: $(num_methods)"
dists = zeros((num_methods, num_methods))
@showprogress Threads.@threads for (i, j) in Choose(2, num_methods) |> collect
dists[i, j] = distance(all_methods[i], all_methods[j])
end
dists
end

42
src/distance.jl Normal file
View file

@ -0,0 +1,42 @@
function distance_seqs(dist_func, x, y)
if length(x) > length(y)
x, y = y, x
end
diffs = minimum(Choose(length(x), length(y))) do y_ids
sum(enumerate(y_ids); init=0) do v
xid, yid = v
dist_func(x[xid], y[yid])
end
end
diffs / length(x)
end
function distance(x::MethodInfo, y::MethodInfo)::Float64
arg_types = if isempty(x.sig.types) || isempty(y.sig.types)
1
else
distance_seqs(x.sig.types, y.sig.types) do x_type, y_type
x_type == y_type
end
end
arg_names = if isempty(x.args) || isempty(y.args)
1
else
distance_seqs(x.args, y.args) do x_arg, y_arg
x_arg == y_arg
end
end
kwargs = if isempty(x.kwargs) || isempty(y.kwargs)
1
else
distance_seqs(x.kwargs, y.kwargs) do x_kw, y_kw
x_kw == y_kw
end
end
mean((arg_types, arg_names, kwargs))
end
function distance(x::Method, y::Method)
distance(get_method_info(x), get_method_info(y))
end

40
src/method_info.jl Normal file
View file

@ -0,0 +1,40 @@
struct MethodInfo
"type vars (maybe `TypeVar`s?)"
tv::Vector{Any}
"signature"
sig::Type
"argument names, including the function itself"
args::Vector{Symbol}
"keyword arguments"
kwargs::Vector{Symbol}
"definition location"
file::String
"definition location"
line::Int64
end
function get_method_info(m::Method)
tv = Any[]
sig = m.sig
while isa(sig, UnionAll)
push!(tv, sig.var)
sig = sig.body
end
file, line = Base.updated_methodloc(m)
argnames = Base.method_argnames(m)
if length(argnames) >= m.nargs
show_env = Base.ImmutableDict{Symbol,Any}()
for t in tv
show_env = Base.ImmutableDict(show_env, :unionall_env => t)
end
decls = Tuple{String,String}[Base.argtype_decl(show_env, argnames[i], sig, i, m.nargs, m.isva)
for i = 1:m.nargs]
# decls[1] = ("", sprint(Base.show_signature_function, Base.unwrapva(sig.parameters[1]), false, decls[1][1], html,
# context = show_env))
else
# decls = Tuple{String,String}[("", "") for i = 1:length(sig.parameters::SimpleVector)]
end
# return tv, decls, file, line
return MethodInfo(tv, sig, argnames, Base.kwarg_decl(m), file, line)
end

46
src/utils/choose.jl Normal file
View file

@ -0,0 +1,46 @@
struct Choose{M,N}
function Choose(m::Integer, n::Integer)
if m > n
m, n = n, m
end
new{m,n}()
end
end
function Base.iterate(::Choose{M,N}) where {M,N}
if M == 0
return nothing
end
ret = ntuple(identity, M)
ret, ret
end
function Base.iterate(::Choose{M,N}, state) where {M,N}
if state[1] == N - M + 1
return nothing
end
for (i, v) in Iterators.reverse(enumerate(state))
if v != N - M + i
ret = ntuple(M) do j
if j < i
state[j]
else
v + (j - i) + 1
end
end
return ret, ret
end
end
error("must not reachable")
end
Base.haslength(::Choose{M,N}) where {M,N} = Base.HasLength()
function Base.length(::Choose{M,N}) where {M,N}
binomial(N, M)
end
function Base.eltype(::Choose{M,N}) where {M,N}
NTuple{M, Int64}
end