Skip to content

Commit 0c8ea4a

Browse files
authored
Fangrui/export systemc (#8)
* changed docker file * update to 24.04 & chisel 6.7.0 * update to chisel 6.7.0 * fix warning * update systemc * update make file * add busy mark in ctrl bundle
1 parent 03f366e commit 0c8ea4a

4 files changed

Lines changed: 151 additions & 4 deletions

File tree

src/main/scala/alu/mma/cu/controlUnit.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ class ControlUnit(val n: Int = 8) extends Module {
1414
val cbus_out = Output(Vec(n * n, new NCoreMMALUCtrlBundle()))
1515
val cbus_dat_clct = Output(Bool())
1616
val cbus_use_accum = Output(Bool())
17+
val clct = Output(Bool())
1718
})
1819
// Assign each element with diagnal control signal
1920
val reg = RegInit(VecInit(Seq.fill(2*n-1)(0.U.asTypeOf(new NCoreMMALUCtrlBundle()))))
2021
// val clct = Wire(Bool())
2122
val or_g = Module{new ORGate(2*n-1)}
2223
io.cbus_dat_clct :<>= or_g.io.out
23-
24+
io.clct :<>= reg(2 * n - 2).busy
2425
io.cbus_use_accum :<>= reg(2 * n - 2).use_accum
2526

2627
// 1D systolic array for control

src/main/scala/alu/mma/mma.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import chisel3._
3737
val ctrl_array = Module(new cu.ControlUnit(n))
3838
ctrl_array.io.cbus_in := io.ctrl
3939
dclct.io.dat_clct <> ctrl_array.io.cbus_dat_clct
40-
io.clct <> ctrl_array.io.cbus_dat_clct
40+
io.clct <> ctrl_array.io.clct
4141
dclct.io.use_accum <> ctrl_array.io.cbus_use_accum
4242

4343
val sarray = Module(new sa.SystolicArray2D(n, nbits))

src/main/scala/isa/micro_op/MMALUMicroCode.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ import chisel3.util._
77
class NCoreMMALUCtrlBundle () extends Bundle {
88
val keep = Bool()
99
val use_accum = Bool()
10+
val busy = Bool()
1011
}

src/test/scala/alu/mma/MMALUSpec.scala

Lines changed: 147 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ class MMALUSpec extends AnyFlatSpec {
7474
dut.io.ctrl.keep.poke(true)
7575
else
7676
dut.io.ctrl.keep.poke(false)
77+
if (i_tick < _n)
78+
dut.io.ctrl.busy.poke(true)
79+
else
80+
dut.io.ctrl.busy.poke(false)
7781
dut.io.ctrl.use_accum.poke(false)
7882

7983
// ideally, the array will give _n (diagnal) results per tick
@@ -181,6 +185,10 @@ class MMALUSpec extends AnyFlatSpec {
181185
dut.io.ctrl.keep.poke(true)
182186
else
183187
dut.io.ctrl.keep.poke(false)
188+
if (i_tick < 2 * _n)
189+
dut.io.ctrl.busy.poke(true)
190+
else
191+
dut.io.ctrl.busy.poke(false)
184192
dut.io.ctrl.use_accum.poke(false)
185193

186194
// ideally, the array will give _n (diagnal) results per tick
@@ -282,10 +290,14 @@ class MMALUSpec extends AnyFlatSpec {
282290
dut.io.ctrl.keep.poke(true)
283291
else
284292
dut.io.ctrl.keep.poke(false)
285-
if (i_tick < _n)
293+
if (i_tick < _n) {
294+
dut.io.ctrl.busy.poke(true)
286295
dut.io.ctrl.use_accum.poke(true)
287-
else
296+
}
297+
else {
298+
dut.io.ctrl.busy.poke(false)
288299
dut.io.ctrl.use_accum.poke(false)
300+
}
289301

290302
// ideally, the array will give _n (diagnal) results per tick
291303
dut.clock.step()
@@ -304,4 +316,137 @@ class MMALUSpec extends AnyFlatSpec {
304316
print_helper.printMatrix(_res, _n)
305317
}
306318
}
319+
320+
"MMALU" should "do a generic matrix multiplication in stream" in {
321+
simulate(new MMALU(new MMPE(8, 32), 4, 8, 32)) { dut =>
322+
val print_helper = new testUtil.PrintHelper()
323+
val _n = dut.n
324+
val rand = new Random
325+
val _mat_a = new Array[Int](_n * _n)
326+
val _mat_b = new Array[Int](_n * _n)
327+
val _mat_d = new Array[Int](_n * _n)
328+
val _mat_e = new Array[Int](_n * _n)
329+
val _vec_h = new Array[Int](_n)
330+
val _vec_i = new Array[Int](_n)
331+
val _expected_c = new Array[Int](_n * _n)
332+
val _expected_f = new Array[Int](_n * _n)
333+
var _res_c = new Array[Int](_n * _n)
334+
var _res_f = new Array[Int](_n * _n)
335+
var step = 0
336+
337+
// random initialize the
338+
for (i <- 0 until _n * _n) {
339+
_mat_a(i) = rand.between(-128, 128)
340+
_mat_b(i) = rand.between(-128, 128)
341+
_mat_d(i) = rand.between(-128, 128)
342+
_mat_e(i) = rand.between(-128, 128)
343+
}
344+
for (i <- 0 until _n) {
345+
_vec_h(i) = rand.between(-128, 128)
346+
_vec_i(i) = rand.between(-128, 128)
347+
}
348+
349+
// expected matrix multiplication result
350+
for (_i <- 0 until _n) {
351+
for (_j <- 0 until _n) {
352+
for (_m <- 0 until _n) {
353+
_expected_c(_i * _n + _j) += _mat_a(_m * _n + _j) * _mat_b(_m * _n + _i)
354+
_expected_f(_i * _n + _j) += _mat_d(_m * _n + _j) * _mat_e(_m * _n + _i)
355+
}
356+
_expected_c(_i * _n + _j) += _vec_h(_j)
357+
_expected_f(_i * _n + _j) += _vec_i(_j)
358+
}
359+
}
360+
361+
// print the expected results
362+
println("===== MAT A =====")
363+
print_helper.printMatrix(_mat_a, _n)
364+
println("===== MAT B =====")
365+
print_helper.printMatrix(_mat_b, _n)
366+
println("===== Vec H =====")
367+
print_helper.printVector(_vec_h, _n)
368+
println("+++++ MAT C +++++")
369+
print_helper.printMatrix(_expected_c, _n)
370+
println("===== MAT D =====")
371+
print_helper.printMatrix(_mat_d, _n)
372+
println("===== MAT E =====")
373+
print_helper.printMatrix(_mat_e, _n)
374+
println("===== Vec I =====")
375+
print_helper.printVector(_vec_i, _n)
376+
println("+++++ MAT F +++++")
377+
print_helper.printMatrix(_expected_f, _n)
378+
379+
// systolic arrays has latency of 3 * _n - 2
380+
// and the second work period is 2 * n - 1
381+
for (i_tick <- 0 until 3 * _n - 2 + _n) {
382+
383+
// poke the input vector
384+
// with data feeder the data latency is just n ticks
385+
// after n ticks the mmalu will have no dependency
386+
// on the current register
387+
if (i_tick < _n) {
388+
// println("Tick @ " + i_tick + " Reading Reg A & B")
389+
for (_i <- 0 until _n){
390+
dut.io.in_a(_i).poke(_mat_a(i_tick * _n + _i))
391+
dut.io.in_b(_i).poke(_mat_b(i_tick * _n + _i))
392+
dut.io.in_accum(_i).poke(_vec_h(_i))
393+
}
394+
} else if (i_tick < 2 * _n) {
395+
// println("Tick @ " + i_tick + " Reading Reg C & D")
396+
for (_i <- 0 until _n){
397+
dut.io.in_a(_i).poke(_mat_d((i_tick % _n) * _n + _i))
398+
dut.io.in_b(_i).poke(_mat_e((i_tick % _n) * _n + _i))
399+
dut.io.in_accum(_i).poke(_vec_i(_i))
400+
}
401+
} else {
402+
for (_i <- 0 until _n){
403+
dut.io.in_a(_i).poke(0)
404+
dut.io.in_b(_i).poke(0)
405+
dut.io.in_accum(_i).poke(0)
406+
}
407+
}
408+
409+
// Only the first _n ticks need accumlate signal
410+
// The rest of the control signal will hand over
411+
// to a dedicated systolic-ish control bus
412+
if (i_tick != _n - 1 && i_tick != 2 * _n )
413+
dut.io.ctrl.keep.poke(true)
414+
else
415+
dut.io.ctrl.keep.poke(false)
416+
if (i_tick < 2 * _n)
417+
dut.io.ctrl.busy.poke(true)
418+
else
419+
dut.io.ctrl.busy.poke(false)
420+
dut.io.ctrl.use_accum.poke(true)
421+
422+
// ideally, the array will give _n (diagnal) results per tick
423+
dut.clock.step()
424+
425+
// systolic array will start to spit out after _n - 1 ticks for mat_c
426+
println("Tick @ " + i_tick + " clct signal " + dut.io.clct.peek().litValue.toInt)
427+
if (i_tick >= 2 * _n - 2 && i_tick < 3 * _n - 2) {
428+
for (_i <- 0 until _n) {
429+
_res_c(step * _n + _i) = dut.io.out(_i).peek().litValue.toInt
430+
println("Tick @ " + i_tick + " Mat C producing at location (" + _i + ", " + step + "): " + _res_c(step * _n + _i))
431+
dut.io.out(_i).expect(_expected_c(step * _n + _i))
432+
}
433+
step = step + 1
434+
}
435+
// systolic array will start to generate again after 3 * _n - 2 ticks
436+
if (i_tick >= 3 * _n - 2 && i_tick < 4 * _n - 2){
437+
for (_i <- 0 until _n) {
438+
_res_f((step % _n) * _n + _i) = dut.io.out(_i).peek().litValue.toInt
439+
println("OUT Tick @ " + i_tick + " Mat F producing at location (" + _i + ", " + step % _n + "): " + _res_f((step % _n) * _n + _i))
440+
dut.io.out(_i).expect(_expected_f((step % _n) * _n + _i))
441+
}
442+
step = step + 1
443+
}
444+
445+
}
446+
println("+++++ MAT C from HW ++++")
447+
print_helper.printMatrix(_res_c, _n)
448+
println("+++++ MAT F from HW ++++")
449+
print_helper.printMatrix(_res_f, _n)
450+
}
451+
}
307452
}

0 commit comments

Comments
 (0)