diff --git a/ast.go b/ast.go index 84fa26d..b51dbe1 100644 --- a/ast.go +++ b/ast.go @@ -1,6 +1,8 @@ package avo import ( + "go/types" + "github.com/mmcloughlin/avo/operand" "github.com/mmcloughlin/avo/reg" ) @@ -109,7 +111,9 @@ func NewFile() *File { // Function represents an assembly function. type Function struct { - name string + Name string + Signature *types.Signature + Nodes []Node // LabelTarget maps from label name to the following instruction. @@ -121,7 +125,7 @@ type Function struct { func NewFunction(name string) *Function { return &Function{ - name: name, + Name: name, } } @@ -149,9 +153,6 @@ func (f *Function) Instructions() []*Instruction { return is } -// Name returns the function name. -func (f *Function) Name() string { return f.name } - // FrameBytes returns the size of the stack frame in bytes. func (f *Function) FrameBytes() int { // TODO(mbm): implement diff --git a/gotypes/gotypes.go b/gotypes/gotypes.go new file mode 100644 index 0000000..8981654 --- /dev/null +++ b/gotypes/gotypes.go @@ -0,0 +1,22 @@ +package gotypes + +import ( + "errors" + "go/token" + "go/types" +) + +func ParseSignature(expr string) (*types.Signature, error) { + tv, err := types.Eval(token.NewFileSet(), nil, token.NoPos, expr) + if err != nil { + return nil, err + } + if tv.Value != nil { + return nil, errors.New("signature expression should have nil value") + } + s, ok := tv.Type.(*types.Signature) + if !ok { + return nil, errors.New("provided type is not a function signature") + } + return s, nil +} diff --git a/gotypes/gotypes_test.go b/gotypes/gotypes_test.go new file mode 100644 index 0000000..fc18929 --- /dev/null +++ b/gotypes/gotypes_test.go @@ -0,0 +1,95 @@ +package gotypes + +import ( + "go/token" + "go/types" + "strings" + "testing" +) + +func TestParseSignature(t *testing.T) { + cases := []struct { + Expr string + ExpectParams *types.Tuple + ExpectReturn *types.Tuple + }{ + { + Expr: "func()", + }, + { + Expr: "func(x, y uint64)", + ExpectParams: types.NewTuple( + types.NewParam(token.NoPos, nil, "x", types.Typ[types.Uint64]), + types.NewParam(token.NoPos, nil, "y", types.Typ[types.Uint64]), + ), + }, + { + Expr: "func(n int, s []string) byte", + ExpectParams: types.NewTuple( + types.NewParam(token.NoPos, nil, "n", types.Typ[types.Int]), + types.NewParam(token.NoPos, nil, "s", types.NewSlice(types.Typ[types.String])), + ), + ExpectReturn: types.NewTuple( + types.NewParam(token.NoPos, nil, "", types.Typ[types.Byte]), + ), + }, + { + Expr: "func(x, y int) (x0, y0 int, s string)", + ExpectParams: types.NewTuple( + types.NewParam(token.NoPos, nil, "x", types.Typ[types.Int]), + types.NewParam(token.NoPos, nil, "y", types.Typ[types.Int]), + ), + ExpectReturn: types.NewTuple( + types.NewParam(token.NoPos, nil, "x0", types.Typ[types.Int]), + types.NewParam(token.NoPos, nil, "y0", types.Typ[types.Int]), + types.NewParam(token.NoPos, nil, "s", types.Typ[types.String]), + ), + }, + } + for _, c := range cases { + s, err := ParseSignature(c.Expr) + if err != nil { + t.Fatal(err) + } + if !TypesTuplesEqual(s.Params(), c.ExpectParams) { + t.Errorf("parameter mismatch\ngot %#v\nexpect %#v\n", s.Params(), c.ExpectParams) + } + if !TypesTuplesEqual(s.Results(), c.ExpectReturn) { + t.Errorf("return value(s) mismatch\ngot %#v\nexpect %#v\n", s.Results(), c.ExpectReturn) + } + } +} + +func TestParseSignatureErrors(t *testing.T) { + cases := []struct { + Expr string + ErrorContains string + }{ + {"idkjklol", "undeclared name"}, + {"struct{}", "not a function signature"}, + {"uint32(0xfeedbeef)", "should have nil value"}, + } + for _, c := range cases { + s, err := ParseSignature(c.Expr) + if s != nil || err == nil || !strings.Contains(err.Error(), c.ErrorContains) { + t.Errorf("expect error from expression %s\ngot: %s\nexpect substring: %s\n", c.Expr, err, c.ErrorContains) + } + } +} + +func TypesTuplesEqual(a, b *types.Tuple) bool { + if a.Len() != b.Len() { + return false + } + n := a.Len() + for i := 0; i < n; i++ { + if !TypesVarsEqual(a.At(i), b.At(i)) { + return false + } + } + return true +} + +func TypesVarsEqual(a, b *types.Var) bool { + return a.Name() == b.Name() && types.Identical(a.Type(), b.Type()) +} diff --git a/printer.go b/printer.go index 1eaf6ae..4090bb4 100644 --- a/printer.go +++ b/printer.go @@ -68,7 +68,7 @@ func (p *GoPrinter) multicomment(lines []string) { } func (p *GoPrinter) function(f *Function) { - p.printf("TEXT %s%s(SB),0,$%d-%d\n", dot, f.Name(), f.FrameBytes(), f.ArgumentBytes()) + p.printf("TEXT %s%s(SB),0,$%d-%d\n", dot, f.Name, f.FrameBytes(), f.ArgumentBytes()) for _, node := range f.Nodes { switch n := node.(type) {