diff --git a/gotypes/components.go b/gotypes/components.go index dd224b3..2206afa 100644 --- a/gotypes/components.go +++ b/gotypes/components.go @@ -143,7 +143,7 @@ func (c *component) Imag() Component { } func (c *component) Index(i int) Component { - a, ok := c.typ.(*types.Array) + a, ok := c.typ.Underlying().(*types.Array) if !ok { return errorf("not array type") } @@ -213,17 +213,17 @@ func (c *component) sub(suffix string, offset int, t types.Type) *component { } func isslice(t types.Type) bool { - _, ok := t.(*types.Slice) + _, ok := t.Underlying().(*types.Slice) return ok } func isstring(t types.Type) bool { - b, ok := t.(*types.Basic) + b, ok := t.Underlying().(*types.Basic) return ok && b.Kind() == types.String } func iscomplex(t types.Type) bool { - b, ok := t.(*types.Basic) + b, ok := t.Underlying().(*types.Basic) return ok && (b.Info()&types.IsComplex) != 0 } diff --git a/gotypes/components_test.go b/gotypes/components_test.go index b0b2366..2a045e0 100644 --- a/gotypes/components_test.go +++ b/gotypes/components_test.go @@ -1,6 +1,7 @@ package gotypes import ( + "go/token" "go/types" "strings" "testing" @@ -96,3 +97,132 @@ func TestComponentErrorChaining(t *testing.T) { } } } + +func TestComponentDeconstruction(t *testing.T) { + cases := []struct { + Name string + Type types.Type + Chain func(Component) Component + Param string + Offset int + }{ + { + Name: "slice_base", + Type: types.NewSlice(types.Typ[types.Uint64]), + Chain: func(c Component) Component { return c.Base() }, + Param: "base", + Offset: 0, + }, + { + Name: "slice_len", + Type: types.NewSlice(types.Typ[types.Uint64]), + Chain: func(c Component) Component { return c.Len() }, + Param: "len", + Offset: 8, + }, + { + Name: "slice_cap", + Type: types.NewSlice(types.Typ[types.Uint64]), + Chain: func(c Component) Component { return c.Cap() }, + Param: "cap", + Offset: 16, + }, + { + Name: "string_base", + Type: types.Typ[types.String], + Chain: func(c Component) Component { return c.Base() }, + Param: "base", + Offset: 0, + }, + { + Name: "slice_len", + Type: types.Typ[types.String], + Chain: func(c Component) Component { return c.Len() }, + Param: "len", + Offset: 8, + }, + { + Name: "complex64_real", + Type: types.Typ[types.Complex64], + Chain: func(c Component) Component { return c.Real() }, + Param: "real", + Offset: 0, + }, + { + Name: "complex64_imag", + Type: types.Typ[types.Complex64], + Chain: func(c Component) Component { return c.Imag() }, + Param: "imag", + Offset: 4, + }, + { + Name: "complex128_real", + Type: types.Typ[types.Complex128], + Chain: func(c Component) Component { return c.Real() }, + Param: "real", + Offset: 0, + }, + { + Name: "complex128_imag", + Type: types.Typ[types.Complex128], + Chain: func(c Component) Component { return c.Imag() }, + Param: "imag", + Offset: 8, + }, + { + Name: "array", + Type: types.NewArray(types.Typ[types.Uint32], 7), + Chain: func(c Component) Component { return c.Index(3) }, + Param: "3", + Offset: 12, + }, + { + Name: "struct", + Type: types.NewStruct([]*types.Var{ + types.NewField(token.NoPos, nil, "Byte", types.Typ[types.Byte], false), + types.NewField(token.NoPos, nil, "Uint64", types.Typ[types.Uint64], false), + }, nil), + Chain: func(c Component) Component { return c.Field("Uint64") }, + Param: "Uint64", + Offset: 8, + }, + } + + // For every test case, generate the same case but when the type is wrapped in + // a named type. + n := len(cases) + for i := 0; i < n; i++ { + wrapped := cases[i] + wrapped.Name += "_wrapped" + wrapped.Type = types.NewNamed( + types.NewTypeName(token.NoPos, nil, "wrapped", nil), + wrapped.Type, + nil, + ) + cases = append(cases, wrapped) + } + + for _, c := range cases { + c := c // avoid scopelint error + t.Run(c.Name, func(t *testing.T) { + t.Log(c.Type) + base := operand.NewParamAddr("test", 0) + comp := NewComponent(c.Type, base) + comp = c.Chain(comp) + + b, err := comp.Resolve() + if err != nil { + t.Fatal(err) + } + + expectname := "test_" + c.Param + if b.Addr.Symbol.Name != expectname { + t.Errorf("parameter name %q; expected %q", b.Addr.Symbol.Name, expectname) + } + + if b.Addr.Disp != c.Offset { + t.Errorf("offset %d; expected %d", b.Addr.Disp, c.Offset) + } + }) + } +}