@@ -74,6 +74,10 @@ class MMALUSpec extends AnyFlatSpec {
7474 dut.io.ctrl.keep.poke(true )
7575 else
7676 dut.io.ctrl.keep.poke(false )
77+ if (i_tick < _n)
78+ dut.io.ctrl.busy.poke(true )
79+ else
80+ dut.io.ctrl.busy.poke(false )
7781 dut.io.ctrl.use_accum.poke(false )
7882
7983 // ideally, the array will give _n (diagnal) results per tick
@@ -181,6 +185,10 @@ class MMALUSpec extends AnyFlatSpec {
181185 dut.io.ctrl.keep.poke(true )
182186 else
183187 dut.io.ctrl.keep.poke(false )
188+ if (i_tick < 2 * _n)
189+ dut.io.ctrl.busy.poke(true )
190+ else
191+ dut.io.ctrl.busy.poke(false )
184192 dut.io.ctrl.use_accum.poke(false )
185193
186194 // ideally, the array will give _n (diagnal) results per tick
@@ -282,10 +290,14 @@ class MMALUSpec extends AnyFlatSpec {
282290 dut.io.ctrl.keep.poke(true )
283291 else
284292 dut.io.ctrl.keep.poke(false )
285- if (i_tick < _n)
293+ if (i_tick < _n) {
294+ dut.io.ctrl.busy.poke(true )
286295 dut.io.ctrl.use_accum.poke(true )
287- else
296+ }
297+ else {
298+ dut.io.ctrl.busy.poke(false )
288299 dut.io.ctrl.use_accum.poke(false )
300+ }
289301
290302 // ideally, the array will give _n (diagnal) results per tick
291303 dut.clock.step()
@@ -304,4 +316,137 @@ class MMALUSpec extends AnyFlatSpec {
304316 print_helper.printMatrix(_res, _n)
305317 }
306318 }
319+
320+ " MMALU" should " do a generic matrix multiplication in stream" in {
321+ simulate(new MMALU (new MMPE (8 , 32 ), 4 , 8 , 32 )) { dut =>
322+ val print_helper = new testUtil.PrintHelper ()
323+ val _n = dut.n
324+ val rand = new Random
325+ val _mat_a = new Array [Int ](_n * _n)
326+ val _mat_b = new Array [Int ](_n * _n)
327+ val _mat_d = new Array [Int ](_n * _n)
328+ val _mat_e = new Array [Int ](_n * _n)
329+ val _vec_h = new Array [Int ](_n)
330+ val _vec_i = new Array [Int ](_n)
331+ val _expected_c = new Array [Int ](_n * _n)
332+ val _expected_f = new Array [Int ](_n * _n)
333+ var _res_c = new Array [Int ](_n * _n)
334+ var _res_f = new Array [Int ](_n * _n)
335+ var step = 0
336+
337+ // random initialize the
338+ for (i <- 0 until _n * _n) {
339+ _mat_a(i) = rand.between(- 128 , 128 )
340+ _mat_b(i) = rand.between(- 128 , 128 )
341+ _mat_d(i) = rand.between(- 128 , 128 )
342+ _mat_e(i) = rand.between(- 128 , 128 )
343+ }
344+ for (i <- 0 until _n) {
345+ _vec_h(i) = rand.between(- 128 , 128 )
346+ _vec_i(i) = rand.between(- 128 , 128 )
347+ }
348+
349+ // expected matrix multiplication result
350+ for (_i <- 0 until _n) {
351+ for (_j <- 0 until _n) {
352+ for (_m <- 0 until _n) {
353+ _expected_c(_i * _n + _j) += _mat_a(_m * _n + _j) * _mat_b(_m * _n + _i)
354+ _expected_f(_i * _n + _j) += _mat_d(_m * _n + _j) * _mat_e(_m * _n + _i)
355+ }
356+ _expected_c(_i * _n + _j) += _vec_h(_j)
357+ _expected_f(_i * _n + _j) += _vec_i(_j)
358+ }
359+ }
360+
361+ // print the expected results
362+ println(" ===== MAT A =====" )
363+ print_helper.printMatrix(_mat_a, _n)
364+ println(" ===== MAT B =====" )
365+ print_helper.printMatrix(_mat_b, _n)
366+ println(" ===== Vec H =====" )
367+ print_helper.printVector(_vec_h, _n)
368+ println(" +++++ MAT C +++++" )
369+ print_helper.printMatrix(_expected_c, _n)
370+ println(" ===== MAT D =====" )
371+ print_helper.printMatrix(_mat_d, _n)
372+ println(" ===== MAT E =====" )
373+ print_helper.printMatrix(_mat_e, _n)
374+ println(" ===== Vec I =====" )
375+ print_helper.printVector(_vec_i, _n)
376+ println(" +++++ MAT F +++++" )
377+ print_helper.printMatrix(_expected_f, _n)
378+
379+ // systolic arrays has latency of 3 * _n - 2
380+ // and the second work period is 2 * n - 1
381+ for (i_tick <- 0 until 3 * _n - 2 + _n) {
382+
383+ // poke the input vector
384+ // with data feeder the data latency is just n ticks
385+ // after n ticks the mmalu will have no dependency
386+ // on the current register
387+ if (i_tick < _n) {
388+ // println("Tick @ " + i_tick + " Reading Reg A & B")
389+ for (_i <- 0 until _n){
390+ dut.io.in_a(_i).poke(_mat_a(i_tick * _n + _i))
391+ dut.io.in_b(_i).poke(_mat_b(i_tick * _n + _i))
392+ dut.io.in_accum(_i).poke(_vec_h(_i))
393+ }
394+ } else if (i_tick < 2 * _n) {
395+ // println("Tick @ " + i_tick + " Reading Reg C & D")
396+ for (_i <- 0 until _n){
397+ dut.io.in_a(_i).poke(_mat_d((i_tick % _n) * _n + _i))
398+ dut.io.in_b(_i).poke(_mat_e((i_tick % _n) * _n + _i))
399+ dut.io.in_accum(_i).poke(_vec_i(_i))
400+ }
401+ } else {
402+ for (_i <- 0 until _n){
403+ dut.io.in_a(_i).poke(0 )
404+ dut.io.in_b(_i).poke(0 )
405+ dut.io.in_accum(_i).poke(0 )
406+ }
407+ }
408+
409+ // Only the first _n ticks need accumlate signal
410+ // The rest of the control signal will hand over
411+ // to a dedicated systolic-ish control bus
412+ if (i_tick != _n - 1 && i_tick != 2 * _n )
413+ dut.io.ctrl.keep.poke(true )
414+ else
415+ dut.io.ctrl.keep.poke(false )
416+ if (i_tick < 2 * _n)
417+ dut.io.ctrl.busy.poke(true )
418+ else
419+ dut.io.ctrl.busy.poke(false )
420+ dut.io.ctrl.use_accum.poke(true )
421+
422+ // ideally, the array will give _n (diagnal) results per tick
423+ dut.clock.step()
424+
425+ // systolic array will start to spit out after _n - 1 ticks for mat_c
426+ println(" Tick @ " + i_tick + " clct signal " + dut.io.clct.peek().litValue.toInt)
427+ if (i_tick >= 2 * _n - 2 && i_tick < 3 * _n - 2 ) {
428+ for (_i <- 0 until _n) {
429+ _res_c(step * _n + _i) = dut.io.out(_i).peek().litValue.toInt
430+ println(" Tick @ " + i_tick + " Mat C producing at location (" + _i + " , " + step + " ): " + _res_c(step * _n + _i))
431+ dut.io.out(_i).expect(_expected_c(step * _n + _i))
432+ }
433+ step = step + 1
434+ }
435+ // systolic array will start to generate again after 3 * _n - 2 ticks
436+ if (i_tick >= 3 * _n - 2 && i_tick < 4 * _n - 2 ){
437+ for (_i <- 0 until _n) {
438+ _res_f((step % _n) * _n + _i) = dut.io.out(_i).peek().litValue.toInt
439+ println(" OUT Tick @ " + i_tick + " Mat F producing at location (" + _i + " , " + step % _n + " ): " + _res_f((step % _n) * _n + _i))
440+ dut.io.out(_i).expect(_expected_f((step % _n) * _n + _i))
441+ }
442+ step = step + 1
443+ }
444+
445+ }
446+ println(" +++++ MAT C from HW ++++" )
447+ print_helper.printMatrix(_res_c, _n)
448+ println(" +++++ MAT F from HW ++++" )
449+ print_helper.printMatrix(_res_f, _n)
450+ }
451+ }
307452}
0 commit comments