From 301d0c137aded6cec77817a08625d33d965201ec Mon Sep 17 00:00:00 2001 From: Michael McLoughlin Date: Fri, 4 Jan 2019 00:45:01 -0800 Subject: [PATCH] internal/stack: helper package for querying stack frames Intended for #5. Also replaces a helper function in the `printer` package. --- internal/stack/stack.go | 72 ++++++++++++++++++++++++++++++++++++ internal/stack/stack_test.go | 34 +++++++++++++++++ printer/printer.go | 19 ++-------- 3 files changed, 109 insertions(+), 16 deletions(-) create mode 100644 internal/stack/stack.go create mode 100644 internal/stack/stack_test.go diff --git a/internal/stack/stack.go b/internal/stack/stack.go new file mode 100644 index 0000000..5944c89 --- /dev/null +++ b/internal/stack/stack.go @@ -0,0 +1,72 @@ +package stack + +import ( + "path" + "runtime" + "strings" +) + +// Frames returns at most max callstack Frames, starting with its caller and +// skipping skip Frames. +func Frames(skip, max int) []runtime.Frame { + pc := make([]uintptr, max) + n := runtime.Callers(skip+2, pc) + if n == 0 { + return nil + } + pc = pc[:n] + frames := runtime.CallersFrames(pc) + var fs []runtime.Frame + for { + f, more := frames.Next() + fs = append(fs, f) + if !more { + break + } + } + return fs +} + +// Match returns the first stack frame for which the predicate function returns +// true. Returns nil if no match is found. Starts matching after skip frames, +// starting with its caller. +func Match(skip int, predicate func(runtime.Frame) bool) *runtime.Frame { + i, n := skip+1, 16 + for { + fs := Frames(i, n) + for _, f := range fs { + if predicate(f) { + return &f + } + } + if len(fs) < n { + break + } + i += n + } + return nil +} + +// Main returns the main() function Frame. +func Main() *runtime.Frame { + return Match(1, func(f runtime.Frame) bool { + return f.Function == "main.main" + }) +} + +// ExternalCaller returns the first frame outside the callers package. +func ExternalCaller() *runtime.Frame { + var first *runtime.Frame + return Match(1, func(f runtime.Frame) bool { + if first == nil { + first = &f + } + return pkg(first.Function) != pkg(f.Function) + }) +} + +func pkg(ident string) string { + dir, name := path.Split(ident) + parts := strings.Split(name, ".") + return dir + parts[0] +} diff --git a/internal/stack/stack_test.go b/internal/stack/stack_test.go new file mode 100644 index 0000000..4b833d1 --- /dev/null +++ b/internal/stack/stack_test.go @@ -0,0 +1,34 @@ +package stack_test + +import ( + "runtime" + "testing" + + "github.com/mmcloughlin/avo/internal/stack" +) + +const pkg = "github.com/mmcloughlin/avo/internal/stack_test" + +func TestFramesFirst(t *testing.T) { + fs := stack.Frames(0, 1) + if len(fs) == 0 { + t.Fatalf("empty slice") + } + got := fs[0].Function + expect := pkg + ".TestFramesFirst" + if got != expect { + t.Fatalf("bad function name %s; expect %s", got, expect) + } +} + +func TestMatchFirst(t *testing.T) { + first := stack.Match(0, func(_ runtime.Frame) bool { return true }) + if first == nil { + t.Fatalf("nil match") + } + got := first.Function + expect := pkg + ".TestMatchFirst" + if got != expect { + t.Fatalf("bad function name %s; expect %s", got, expect) + } +} diff --git a/printer/printer.go b/printer/printer.go index c237b18..76cc493 100644 --- a/printer/printer.go +++ b/printer/printer.go @@ -4,10 +4,10 @@ import ( "fmt" "os" "path/filepath" - "runtime" "strings" "github.com/mmcloughlin/avo" + "github.com/mmcloughlin/avo/internal/stack" ) type Printer interface { @@ -67,21 +67,8 @@ func (c Config) GeneratedWarning() string { // mainfile attempts to determine the file path of the main function by // inspecting the stack. Returns empty string on failure. func mainfile() string { - pc := make([]uintptr, 10) - n := runtime.Callers(0, pc) - if n == 0 { - return "" - } - pc = pc[:n] - frames := runtime.CallersFrames(pc) - for { - frame, more := frames.Next() - if frame.Function == "main.main" { - return frame.File - } - if !more { - break - } + if m := stack.Main(); m != nil { + return m.File } return "" }