Update asm generation scripts.

This commit is contained in:
Ken MacKay
2014-05-13 22:53:12 -07:00
parent 95becf0327
commit b14b86e32f
2 changed files with 281 additions and 108 deletions
+54 -50
View File
@@ -26,66 +26,70 @@ def emit(line, *args):
print s % args
#### set up registers
emit("adiw r30, %s" % (size - init_size)) # move z
emit("adiw r28, %s" % (size - init_size)) # move y
emit("adiw r30, %s", size - init_size) # move z
emit("adiw r28, %s", size - init_size) # move y
for i in xrange(init_size):
emit("ld r%s, x+", rx(i))
for i in xrange(init_size):
emit("ld r%s, y+", ry(i))
# note that this code assume that init_size > 1
#### first two multiplications of initial block
emit("ldi r25, 0")
print ""
emit("ldi r23, 0")
emit("mul r2, r12")
emit("st z+, r0")
emit("mov r22, r1")
print ""
emit("ldi r24, 0")
emit("mul r2, r13")
emit("add r22, r0")
emit("adc r23, r1")
emit("mul r3, r12")
emit("add r22, r0")
emit("adc r23, r1")
emit("adc r24, r25")
emit("st z+, r22")
print ""
if init_size == 1:
emit("mul r2, r12")
emit("st z+, r0")
emit("st z+, r1")
else:
#### first two multiplications of initial block
emit("ldi r23, 0")
emit("mul r2, r12")
emit("st z+, r0")
emit("mov r22, r1")
print ""
emit("ldi r24, 0")
emit("mul r2, r13")
emit("add r22, r0")
emit("adc r23, r1")
emit("mul r3, r12")
emit("add r22, r0")
emit("adc r23, r1")
emit("adc r24, r25")
emit("st z+, r22")
print ""
#### rest of initial block, with moving accumulator registers
acc = [23, 24, 22]
for r in xrange(2, init_size):
emit("ldi r%s, 0", acc[2])
for i in xrange(0, r+1):
emit("mul r%s, r%s", rx(i), ry(r - i))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r25", acc[2])
#### rest of initial block, with moving accumulator registers
acc = [23, 24, 22]
for r in xrange(2, init_size):
emit("ldi r%s, 0", acc[2])
for i in xrange(0, r+1):
emit("mul r%s, r%s", rx(i), ry(r - i))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r25", acc[2])
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
for r in xrange(1, init_size-1):
emit("ldi r%s, 0", acc[2])
for i in xrange(0, init_size-r):
emit("mul r%s, r%s", rx(r+i), ry((init_size-1) - i))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r25", acc[2])
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
emit("mul r%s, r%s", rx(init_size-1), ry(init_size-1))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
for r in xrange(1, init_size-1):
emit("ldi r%s, 0", acc[2])
for i in xrange(0, init_size-r):
emit("mul r%s, r%s", rx(r+i), ry((init_size-1) - i))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r25", acc[2])
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
emit("mul r%s, r%s", rx(init_size-1), ry(init_size-1))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("st z+, r%s", acc[0])
emit("st z+, r%s", acc[1])
emit("st z+, r%s", acc[1])
print ""
#### reset y and z pointers
emit("sbiw r30, %s" % (2 * init_size + 10))
emit("sbiw r28, %s" % (init_size + 10))
emit("sbiw r30, %s", 2 * init_size + 10)
emit("sbiw r28, %s", init_size + 10)
#### load y registers
for i in xrange(10):
@@ -186,9 +190,9 @@ for row in xrange(full_rows):
prev_size = prev_size + 10
if row < full_rows - 1:
#### reset x, y and z pointers
emit("sbiw r30, %s" % (2 * prev_size + 10))
emit("sbiw r28, %s" % (prev_size + 10))
emit("sbiw r26, %s" % (prev_size))
emit("sbiw r30, %s", 2 * prev_size + 10)
emit("sbiw r28, %s", prev_size + 10)
emit("sbiw r26, %s", prev_size)
#### load x and y registers
for i in xrange(10):
+227 -58
View File
@@ -1,68 +1,191 @@
#!/usr/bin/env python
def r(i):
import sys
if len(sys.argv) < 2:
print "Provide the integer size in bytes"
sys.exit(1)
size = int(sys.argv[1])
if size > 40:
print "This script doesn't work with integer size %s due to laziness" % (size)
sys.exit(1)
init_size = size - 20
if size < 20:
init_size = 0
def rg(i):
return i + 2
def lo(i):
return i + 2
def hi(i):
return i + 12
def emit(line, *args):
s = '"' + line + r' \n\t"'
print s % args
#### set up registers
zero = "r25"
emit("ldi %s, 0", zero) # zero register
for i in xrange(20):
emit("ld r%s, x+", r(i))
if init_size > 0:
emit("movw r28, r26") # y = x
h = (init_size + 1)//2
for i in xrange(h):
emit("ld r%s, x+", lo(i))
emit("adiw r28, %s", size - init_size) # move y to other end
for i in xrange(h):
emit("ld r%s, y+", hi(i))
emit("adiw r30, %s", size - init_size) # move z
if init_size == 1:
emit("mul %s, %s", lo(0), hi(0))
emit("st z+, r0")
emit("st z+, r1")
else:
#### first one
print ""
emit("ldi r23, 0")
emit("mul %s, %s", lo(0), hi(0))
emit("st z+, r0")
emit("mov r22, r1")
print ""
#### rest of initial block, with moving accumulator registers
acc = [22, 23, 24]
for r in xrange(1, h):
emit("ldi r%s, 0", acc[2])
for i in xrange(0, (r+2)//2):
emit("mul r%s, r%s", lo(i), hi(r - i))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
lo_r = range(2, 2 + h)
hi_r = range(12, 12 + h)
# now we need to start loading more from the high end
for r in xrange(h, init_size):
hi_r = hi_r[1:] + hi_r[:1]
emit("ld r%s, y+", hi_r[h-1])
emit("ldi r%s, 0", acc[2])
for i in xrange(0, (r+2)//2):
emit("mul r%s, r%s", lo(i), hi_r[h - 1 - i])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
# loaded all of the high end bytes; now need to start loading the rest of the low end
for r in xrange(1, init_size-h):
lo_r = lo_r[1:] + lo_r[:1]
emit("ld r%s, x+", lo_r[h-1])
emit("ldi r%s, 0", acc[2])
for i in xrange(0, (init_size+1 - r)//2):
emit("mul r%s, r%s", lo_r[i], hi_r[h - 1 - i])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
lo_r = lo_r[1:] + lo_r[:1]
emit("ld r%s, x+", lo_r[h-1])
# now we have loaded everything, and we just need to finish the last corner
for r in xrange(init_size-h, init_size-1):
emit("ldi r%s, 0", acc[2])
for i in xrange(0, (init_size+1 - r)//2):
emit("mul r%s, r%s", lo_r[i], hi_r[h - 1 - i])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
lo_r = lo_r[1:] + lo_r[:1] # make the indexing easy
emit("mul r%s, r%s", lo_r[0], hi_r[h - 1])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("st z+, r%s", acc[0])
emit("st z+, r%s", acc[1])
print ""
emit("sbiw r26, %s", init_size) # reset x
emit("sbiw r30, %s", size + init_size) # reset z
# TODO you could do more rows of size 20 here if your integers are larger than 40 bytes
s = size - init_size
for i in xrange(s):
emit("ld r%s, x+", rg(i))
#### first few columns
emit("ldi r27, 0") # zero register
# NOTE: this is only valid if size >= 3
print ""
emit("ldi r23, 0")
emit("mul r2, r2")
emit("mul r%s, r%s", rg(0), rg(0))
emit("st z+, r0")
emit("mov r22, r1")
print ""
emit("ldi r24, 0")
emit("mul r2, r3")
emit("lsl r0")
emit("rol r1")
emit("adc r24, r27") # put carry bit in r24
emit("mul r%s, r%s", rg(0), rg(1))
emit("add r22, r0")
emit("adc r23, r1")
emit("adc r24, r27")
emit("adc r24, %s", zero)
emit("add r22, r0")
emit("adc r23, r1")
emit("adc r24, %s", zero)
emit("st z+, r22")
print ""
emit("ldi r22, 0")
emit("mul r2, r4")
emit("lsl r0")
emit("rol r1")
emit("adc r22, r27") # put carry bit in r22
emit("mul r%s, r%s", rg(0), rg(2))
emit("add r23, r0")
emit("adc r24, r1")
emit("adc r22, r27")
emit("mul r3, r3")
emit("adc r22, %s", zero)
emit("add r23, r0")
emit("adc r24, r1")
emit("adc r22, r27")
emit("adc r22, %s", zero)
emit("mul r%s, r%s", rg(1), rg(1))
emit("add r23, r0")
emit("adc r24, r1")
emit("adc r22, %s", zero)
emit("st z+, r23")
print ""
acc = [23, 24, 22]
old_acc = [25, 26]
for i in xrange(3, 20):
old_acc = [28, 29]
for i in xrange(3, s):
emit("ldi r%s, 0", old_acc[1])
tmp = [acc[1], acc[2]]
acc = [acc[0], old_acc[0], old_acc[1]]
old_acc = tmp
# gather non-equal words
for j in xrange(0, (i+1)//2):
emit("mul r%s, r%s", r(j), r(i-j))
if j == 0:
emit("mov r%s, r0", acc[0])
emit("mov r%s, r1", acc[1])
else:
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r27", acc[2])
emit("mul r%s, r%s", rg(0), rg(i))
emit("mov r%s, r0", acc[0])
emit("mov r%s, r1", acc[1])
for j in xrange(1, (i+1)//2):
emit("mul r%s, r%s", rg(j), rg(i-j))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
# multiply by 2
emit("lsl r%s", acc[0])
emit("rol r%s", acc[1])
@@ -70,52 +193,98 @@ for i in xrange(3, 20):
# add equal word (if any)
if ((i+1) % 2) != 0:
emit("mul r%s, r%s", r(i//2), r(i//2))
emit("mul r%s, r%s", rg(i//2), rg(i//2))
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r27", acc[2])
emit("adc r%s, %s", acc[2], zero)
# add old accumulator
emit("add r%s, r%s", acc[0], old_acc[0])
emit("adc r%s, r%s", acc[1], old_acc[1])
emit("adc r%s, r27", acc[2])
emit("adc r%s, %s", acc[2], zero)
# store
emit("st z+, r%s", acc[0])
print ""
for i in xrange(1, 17):
regs = range(2, 22)
for i in xrange(init_size):
regs = regs[1:] + regs[:1]
emit("ld r%s, x+", regs[19])
for limit in [18, 19]:
emit("ldi r%s, 0", old_acc[1])
tmp = [acc[1], acc[2]]
acc = [acc[0], old_acc[0], old_acc[1]]
old_acc = tmp
# gather non-equal words
emit("mul r%s, r%s", regs[0], regs[limit])
emit("mov r%s, r0", acc[0])
emit("mov r%s, r1", acc[1])
for j in xrange(1, (limit+1)//2):
emit("mul r%s, r%s", regs[j], regs[limit-j])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("ld r0, z") # load stored value from initial block, and add to accumulator (note z does not increment)
emit("add r%s, r0", acc[0])
emit("adc r%s, r25", acc[1])
emit("adc r%s, r25", acc[2])
# multiply by 2
emit("lsl r%s", acc[0])
emit("rol r%s", acc[1])
emit("rol r%s", acc[2])
# add equal word
if limit == 18:
emit("mul r%s, r%s", regs[9], regs[9])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
# add old accumulator
emit("add r%s, r%s", acc[0], old_acc[0])
emit("adc r%s, r%s", acc[1], old_acc[1])
emit("adc r%s, %s", acc[2], zero)
# store
emit("st z+, r%s", acc[0])
print ""
for i in xrange(1, s-3):
emit("ldi r%s, 0", old_acc[1])
tmp = [acc[1], acc[2]]
acc = [acc[0], old_acc[0], old_acc[1]]
old_acc = tmp
# gather non-equal words
for j in xrange(0, (20-i)//2):
emit("mul r%s, r%s", r(i+j), r(19-j))
if j == 0:
emit("mov r%s, r0", acc[0])
emit("mov r%s, r1", acc[1])
else:
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r27", acc[2])
emit("mul r%s, r%s", regs[i], regs[s - 1])
emit("mov r%s, r0", acc[0])
emit("mov r%s, r1", acc[1])
for j in xrange(1, (s-i)//2):
emit("mul r%s, r%s", regs[i+j], regs[s - 1 - j])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
# multiply by 2
emit("lsl r%s", acc[0])
emit("rol r%s", acc[1])
emit("rol r%s", acc[2])
# add equal word (if any)
if ((20-i) % 2) != 0:
emit("mul r%s, r%s", r(i + (20-i)//2), r(i + (20-i)//2))
if ((s-i) % 2) != 0:
emit("mul r%s, r%s", regs[i + (s-i)//2], regs[i + (s-i)//2])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r27", acc[2])
emit("adc r%s, %s", acc[2], zero)
# add old accumulator
emit("add r%s, r%s", acc[0], old_acc[0])
emit("adc r%s, r%s", acc[1], old_acc[1])
emit("adc r%s, r27", acc[2])
emit("adc r%s, %s", acc[2], zero)
# store
emit("st z+, r%s", acc[0])
@@ -123,33 +292,33 @@ for i in xrange(1, 17):
acc = acc[1:] + acc[:1]
emit("ldi r%s, 0", acc[2])
emit("mul r19, r21")
emit("lsl r0")
emit("rol r1")
emit("adc r%s, r27", acc[2])
emit("mul r%s, r%s", regs[17], regs[19])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r27", acc[2])
emit("mul r20, r20")
emit("adc r%s, %s", acc[2], zero)
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r27", acc[2])
emit("adc r%s, %s", acc[2], zero)
emit("mul r%s, r%s", regs[18], regs[18])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("st z+, r%s", acc[0])
print ""
acc = acc[1:] + acc[:1]
emit("ldi r%s, 0", acc[2])
emit("mul r20, r21")
emit("lsl r0")
emit("rol r1")
emit("adc r%s, r27", acc[2])
emit("mul r%s, r%s", regs[18], regs[19])
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, r27", acc[2])
emit("adc r%s, %s", acc[2], zero)
emit("add r%s, r0", acc[0])
emit("adc r%s, r1", acc[1])
emit("adc r%s, %s", acc[2], zero)
emit("st z+, r%s", acc[0])
print ""
emit("mul r21, r21")
emit("mul r%s, r%s", regs[19], regs[19])
emit("add r%s, r0", acc[1])
emit("adc r%s, r1", acc[2])
emit("st z+, r%s", acc[1])