diff --git a/ast.go b/ast.go index 1df68f2..5f2b9ef 100644 --- a/ast.go +++ b/ast.go @@ -67,6 +67,14 @@ func (i Instruction) TargetLabel() *Label { return nil } +func (i Instruction) Registers() []reg.Register { + var rs []reg.Register + for _, op := range i.Operands { + rs = append(rs, operand.Registers(op)...) + } + return rs +} + func (i Instruction) InputRegisters() []reg.Register { var rs []reg.Register for _, op := range i.Inputs { @@ -107,6 +115,9 @@ type Function struct { // LabelTarget maps from label name to the following instruction. LabelTarget map[Label]*Instruction + + // Register allocation. + Allocation reg.Allocation } func NewFunction(name string) *Function { diff --git a/operand/types.go b/operand/types.go index f4b4677..d53ac31 100644 --- a/operand/types.go +++ b/operand/types.go @@ -70,3 +70,16 @@ func Registers(op Op) []reg.Register { } panic("unknown operand type") } + +// ApplyAllocation returns an operand with allocated registers replaced. Registers missing from the allocation are left alone. +func ApplyAllocation(op Op, a reg.Allocation) Op { + switch op := op.(type) { + case reg.Register: + return a.LookupDefault(op) + case Mem: + op.Base = a.LookupDefault(op.Base) + op.Index = a.LookupDefault(op.Index) + return op + } + return op +} diff --git a/pass/alloc.go b/pass/alloc.go new file mode 100644 index 0000000..5dbf0ec --- /dev/null +++ b/pass/alloc.go @@ -0,0 +1,153 @@ +package pass + +import ( + "errors" + "math" + + "github.com/mmcloughlin/avo/reg" +) + +// edge is an edge of the interference graph, indicating that registers X and Y +// must be in non-conflicting registers. +type edge struct { + X, Y reg.Register +} + +type Allocator struct { + registers []reg.Physical + allocation reg.Allocation + edges []*edge + possible map[reg.Virtual][]reg.Physical +} + +func NewAllocator(rs []reg.Physical) (*Allocator, error) { + if len(rs) == 0 { + return nil, errors.New("no registers") + } + return &Allocator{ + registers: rs, + allocation: reg.NewEmptyAllocation(), + possible: map[reg.Virtual][]reg.Physical{}, + }, nil +} + +func NewAllocatorForKind(k reg.Kind) (*Allocator, error) { + f := reg.FamilyOfKind(k) + if f == nil { + return nil, errors.New("unknown register family") + } + return NewAllocator(f.Registers()) +} + +func (a *Allocator) AddInterferenceSet(r reg.Register, s reg.Set) { + for y := range s { + a.AddInterference(r, y) + } +} + +func (a *Allocator) AddInterference(x, y reg.Register) { + a.add(x) + a.add(y) + a.edges = append(a.edges, &edge{X: x, Y: y}) +} + +func (a *Allocator) Allocate() (reg.Allocation, error) { + for a.remaining() > 0 { + a.update() + + v := a.mostrestricted() + if err := a.alloc(v); err != nil { + return nil, err + } + } + return a.allocation, nil +} + +// add adds a register. +func (a *Allocator) add(r reg.Register) { + v, ok := r.(reg.Virtual) + if !ok { + return + } + a.possible[v] = a.registersofsize(v.Bytes()) +} + +// update possible allocations based on edges. +func (a *Allocator) update() error { + var rem []*edge + for _, e := range a.edges { + e.X, e.Y = a.allocation.LookupDefault(e.X), a.allocation.LookupDefault(e.Y) + + px, py := reg.ToPhysical(e.X), reg.ToPhysical(e.Y) + vx, vy := reg.ToVirtual(e.X), reg.ToVirtual(e.Y) + + switch { + case vx != nil && vy != nil: + rem = append(rem, e) + continue + case px != nil && py != nil: + if reg.AreConflicting(px, py) { + return errors.New("impossible register allocation") + } + case px != nil && vy != nil: + a.discardconflicting(vy, px) + case vx != nil && py != nil: + a.discardconflicting(vx, py) + default: + panic("unreachable") + } + } + a.edges = rem + return nil +} + +// mostrestricted returns the virtual register with the least possibilities. +func (a *Allocator) mostrestricted() reg.Virtual { + n := int(math.MaxInt32) + var v reg.Virtual + for r, p := range a.possible { + if len(p) < n { + n = len(p) + v = r + } + } + return v +} + +// discardconflicting removes registers from vs possible list that conflict with p. +func (a *Allocator) discardconflicting(v reg.Virtual, p reg.Physical) { + var rs []reg.Physical + for _, r := range a.possible[v] { + if !reg.AreConflicting(r, p) { + rs = append(rs, r) + } + } + a.possible[v] = rs +} + +// alloc attempts to allocate a register to v. +func (a *Allocator) alloc(v reg.Virtual) error { + ps := a.possible[v] + if len(ps) == 0 { + return errors.New("failed to allocate registers") + } + a.allocation[v] = ps[0] + delete(a.possible, v) + return nil +} + +// remaining returns the number of unallocated registers. +func (a *Allocator) remaining() int { + return len(a.possible) +} + +// registersofsize returns all registers of the given size. +func (a *Allocator) registersofsize(n uint) []reg.Physical { + var rs []reg.Physical + for _, r := range a.registers { + if r.Bytes() == n { + rs = append(rs, r) + } + } + return rs +} diff --git a/pass/reg.go b/pass/reg.go index b69defb..e0e6f17 100644 --- a/pass/reg.go +++ b/pass/reg.go @@ -1,12 +1,18 @@ package pass import ( + "errors" + "github.com/mmcloughlin/avo" + "github.com/mmcloughlin/avo/operand" "github.com/mmcloughlin/avo/reg" ) // Liveness computes register liveness. func Liveness(fn *avo.Function) error { + // Note this implementation is initially naive so as to be "obviously correct". + // There are a well-known optimizations we can apply if necessary. + is := fn.Instructions() // Initialize to empty sets. @@ -51,5 +57,58 @@ func Liveness(fn *avo.Function) error { } func AllocateRegisters(fn *avo.Function) error { + // Build one allocator per register kind and record register interferences. + as := map[reg.Kind]*Allocator{} + for _, i := range fn.Instructions() { + for _, d := range i.OutputRegisters() { + k := d.Kind() + if _, found := as[k]; !found { + a, err := NewAllocatorForKind(k) + if err != nil { + return err + } + as[k] = a + } + + out := i.LiveOut.OfKind(k) + out.Discard(d) + as[k].AddInterferenceSet(d, out) + } + } + + // Execute register allocation. + fn.Allocation = reg.NewEmptyAllocation() + for _, a := range as { + al, err := a.Allocate() + if err != nil { + return err + } + if err := fn.Allocation.Merge(al); err != nil { + return err + } + } + + return nil +} + +func BindRegisters(fn *avo.Function) error { + for _, i := range fn.Instructions() { + for idx := range i.Operands { + i.Operands[idx] = operand.ApplyAllocation(i.Operands[idx], fn.Allocation) + } + } + return nil +} + +func VerifyAllocation(fn *avo.Function) error { + // All registers should be physical. + for _, i := range fn.Instructions() { + for _, r := range i.Registers() { + if reg.ToPhysical(r) == nil { + return errors.New("non physical register found") + } + } + } + return nil } diff --git a/reg/reg_test.go b/reg/reg_test.go index 9243039..14ee949 100644 --- a/reg/reg_test.go +++ b/reg/reg_test.go @@ -23,14 +23,43 @@ func TestSpecBytes(t *testing.T) { } } -func TestVirtualPhysicalHaveDifferentIDs(t *testing.T) { - // Confirm that ID() returns different results even when virtual and physical IDs are the same. - var v Virtual = virtual{id: 42} - var p Physical = register{id: 42} - if uint16(v.VirtualID()) != uint16(p.PhysicalID()) { - t.Fatal("test assumption violated: VirtualID and PhysicalID should agree") +func TestToVirtual(t *testing.T) { + v := GeneralPurpose.Virtual(42, B32) + if ToVirtual(v) != v { + t.Errorf("ToVirtual(v) != v for virtual register") } - if v.ID() == p.ID() { - t.Errorf("virtual and physical IDs should be different") + if ToVirtual(ECX) != nil { + t.Errorf("ToVirtual should be nil for physical registers") + } +} + +func TestToPhysical(t *testing.T) { + v := GeneralPurpose.Virtual(42, B32) + if ToPhysical(v) != nil { + t.Errorf("ToPhysical should be nil for virtual registers") + } + if ToPhysical(ECX) != ECX { + t.Errorf("ToPhysical(p) != p for physical register") + } +} + +func TestAreConflicting(t *testing.T) { + cases := []struct { + X, Y Physical + Expect bool + }{ + {ECX, X3, false}, + {AL, AH, false}, + {AL, AX, true}, + {AL, BX, false}, + {X3, Y4, false}, + {X3, Y3, true}, + {Y3, Z4, false}, + {Y3, Z3, true}, + } + for _, c := range cases { + if AreConflicting(c.X, c.Y) != c.Expect { + t.Errorf("AreConflicting(%s, %s) != %v", c.X, c.Y, c.Expect) + } } } diff --git a/reg/set.go b/reg/set.go index 3fb9a41..fd1ddf3 100644 --- a/reg/set.go +++ b/reg/set.go @@ -1,7 +1,7 @@ package reg // Set is a set of registers. -type Set map[ID]Register +type Set map[Register]bool // NewEmptySet builds an empty register set. func NewEmptySet() Set { @@ -20,7 +20,7 @@ func NewSetFromSlice(rs []Register) Set { // Clone returns a copy of s. func (s Set) Clone() Set { c := NewEmptySet() - for _, r := range s { + for r := range s { c.Add(r) } return c @@ -28,17 +28,17 @@ func (s Set) Clone() Set { // Add r to s. func (s Set) Add(r Register) { - s[r.ID()] = r + s[r] = true } // Discard removes r from s, if present. func (s Set) Discard(r Register) { - delete(s, r.ID()) + delete(s, r) } // Update adds every register in t to s. func (s Set) Update(t Set) { - for _, r := range t { + for r := range t { s.Add(r) } } @@ -52,7 +52,7 @@ func (s Set) Difference(t Set) Set { // DifferenceUpdate removes every element of t from s. func (s Set) DifferenceUpdate(t Set) { - for _, r := range t { + for r := range t { s.Discard(r) } } @@ -62,10 +62,21 @@ func (s Set) Equals(t Set) bool { if len(s) != len(t) { return false } - for _, r := range s { - if _, found := t[r.ID()]; !found { + for r := range s { + if _, found := t[r]; !found { return false } } return true } + +// OfKind returns the set of elements of s with kind k. +func (s Set) OfKind(k Kind) Set { + t := NewEmptySet() + for r := range s { + if r.Kind() == k { + t.Add(r) + } + } + return t +} diff --git a/reg/set_test.go b/reg/set_test.go index 14683fe..7fde164 100644 --- a/reg/set_test.go +++ b/reg/set_test.go @@ -2,11 +2,40 @@ package reg import "testing" -func TestFamilyRegisterSets(t *testing.T) { - fs := []*Family{GeneralPurpose, SIMD} - for _, f := range fs { - if len(f.Set()) != len(f.Registers()) { - t.Fatal("family set and list should have same size") - } +func TestSetRegisterIdentity(t *testing.T) { + rs := []Register{ + NewVirtual(42, GP, B32), + NewVirtual(43, GP, B32), + NewVirtual(42, SSEAVX, B32), + NewVirtual(42, GP, B64), + AL, AH, CL, + AX, R13W, + EDX, R9L, + RCX, R14, + X1, X7, + Y4, Y9, + Z13, Z31, + } + s := NewEmptySet() + for _, r := range rs { + s.Add(r) + s.Add(r) + } + if len(s) != len(rs) { + t.Fatalf("expected set to have same size as slice: got %d expect %d", len(s), len(rs)) + } +} + +func TestSetFamilyRegisters(t *testing.T) { + fs := []*Family{GeneralPurpose, SIMD} + s := NewEmptySet() + expect := 0 + for _, f := range fs { + s.Update(f.Set()) + s.Add(f.Virtual(42, B64)) + expect += len(f.Registers()) + 1 + } + if len(s) != expect { + t.Fatalf("set size mismatch: %d expected %d", len(s), expect) } } diff --git a/reg/types.go b/reg/types.go index 2c2547f..eb16cc5 100644 --- a/reg/types.go +++ b/reg/types.go @@ -1,6 +1,7 @@ package reg import ( + "errors" "fmt" ) @@ -63,7 +64,7 @@ type private interface { } type ( - ID uint32 + ID uint64 VID uint16 PID uint16 ) @@ -81,6 +82,14 @@ type Virtual interface { Register } +// ToVirtual converts r to Virtual if possible, otherwise returns nil. +func ToVirtual(r Register) Virtual { + if v, ok := r.(Virtual); ok { + return v + } + return nil +} + type virtual struct { id VID kind Kind @@ -99,7 +108,7 @@ func (v virtual) VirtualID() VID { return v.id } func (v virtual) Kind() Kind { return v.kind } func (v virtual) ID() ID { - return (ID(1) << 31) | ID(v.VirtualID()) + return (ID(1) << 63) | (ID(v.Size) << 24) | (ID(v.kind) << 16) | ID(v.VirtualID()) } func (v virtual) Asm() string { @@ -115,6 +124,14 @@ type Physical interface { Register } +// ToPhysical converts r to Physical if possible, otherwise returns nil. +func ToPhysical(r Register) Physical { + if p, ok := r.(Physical); ok { + return p + } + return nil +} + type register struct { id PID kind Kind @@ -155,3 +172,34 @@ func (s Spec) Bytes() uint { x := uint(s) return (x >> 1) + (x & 1) } + +// AreConflicting returns whether registers conflict with each other. +func AreConflicting(x, y Physical) bool { + return x.Kind() == y.Kind() && x.PhysicalID() == y.PhysicalID() && (x.Mask()&y.Mask()) != 0 +} + +// Allocation records a register allocation. +type Allocation map[Register]Physical + +func NewEmptyAllocation() Allocation { + return Allocation{} +} + +// Merge allocations from b into a. Errors if there is disagreement on a common +// register. +func (a Allocation) Merge(b Allocation) error { + for r, p := range b { + if alt, found := a[r]; found && alt != p { + return errors.New("disagreement on overlapping register") + } + a[r] = p + } + return nil +} + +func (a Allocation) LookupDefault(r Register) Register { + if p, found := a[r]; found { + return p + } + return r +} diff --git a/reg/x86.go b/reg/x86.go index 606c3a9..9385231 100644 --- a/reg/x86.go +++ b/reg/x86.go @@ -8,6 +8,23 @@ const ( Mask ) +var Families = []*Family{ + GeneralPurpose, + SIMD, +} + +var familiesByKind = map[Kind]*Family{} + +func init() { + for _, f := range Families { + familiesByKind[f.Kind] = f + } +} + +func FamilyOfKind(k Kind) *Family { + return familiesByKind[k] +} + // General purpose registers. var ( GeneralPurpose = &Family{Kind: GP}