diff --git a/finn-rtllib/dwc/hdl/axis_dwc.sv b/finn-rtllib/dwc/hdl/axis_dwc.sv new file mode 100644 index 0000000000..a482ebd5c9 --- /dev/null +++ b/finn-rtllib/dwc/hdl/axis_dwc.sv @@ -0,0 +1,90 @@ +// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +// +// This file is subject to the Xilinx Design License Agreement located +// in the LICENSE.md file in the root directory of this repository. +// +// This file contains confidential and proprietary information of Xilinx, Inc. +// and is protected under U.S. and international copyright and other +// intellectual property laws. +// +// DISCLAIMER +// This disclaimer is not a license and does not grant any rights to the materials +// distributed herewith. Except as otherwise provided in a valid license issued to +// you by Xilinx, and to the maximum extent permitted by applicable law: (1) THESE +// MATERIALS ARE MADE AVAILABLE "AS IS" AND WITH ALL FAULTS, AND XILINX HEREBY +// DISCLAIMS ALL WARRANTIES AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, +// INCLUDING BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR +// FITNESS FOR ANY PARTICULAR PURPOSE; and (2) Xilinx shall not be liable (whether +// in contract or tort, including negligence, or under any other theory of +// liability) for any loss or damage of any kind or nature related to, arising +// under or in connection with these materials, including for any direct, or any +// indirect, special, incidental, or consequential loss or damage (including loss +// of data, profits, goodwill, or any type of loss or damage suffered as a result +// of any action brought by a third party) even if such damage or loss was +// reasonably foreseeable or Xilinx had been advised of the possibility of the +// same. +// +// CRITICAL APPLICATIONS +// Xilinx products are not designed or intended to be fail-safe, or for use in +// any application requiring failsafe performance, such as life-support or safety +// devices or systems, Class III medical devices, nuclear facilities, applications +// related to the deployment of airbags, or any other applications that could lead +// to death, personal injury, or severe property or environmental damage +// (individually and collectively, "Critical Applications"). Customer assumes the +// sole risk and liability of any use of Xilinx products in Critical Applications, +// subject only to applicable laws and regulations governing limitations on product +// liability. +// +// THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS PART OF THIS FILE AT ALL TIMES. + +module axis_dwc #( + parameter integer DEPTH = 512, + parameter integer S_DATA_BITS = 32, + parameter integer M_DATA_BITS = 8 +) ( + input logic aclk, + input logic aresetn, + + input logic s_axis_tvalid, + output logic s_axis_tready, + input logic [S_DATA_BITS-1:0] s_axis_tdata, + input logic [S_DATA_BITS/8-1:0] s_axis_tkeep, + input logic s_axis_tlast, + + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic [M_DATA_BITS-1:0] m_axis_tdata, + output logic [M_DATA_BITS/8-1:0] m_axis_tkeep, + output logic m_axis_tlast +); + +axis_fifo_adapter #( + .DEPTH(DEPTH), + .S_DATA_WIDTH(S_DATA_BITS), + .M_DATA_WIDTH(M_DATA_BITS) +) inst_fifo_adapter ( + .clk (aclk), + .rst (~aresetn), + + .s_axis_tdata (s_axis_tdata), + .s_axis_tkeep (s_axis_tkeep), + .s_axis_tvalid (s_axis_tvalid), + .s_axis_tready (s_axis_tready), + .s_axis_tlast (s_axis_tlast), + .s_axis_tid ('0), + .s_axis_tdest ('0), + .s_axis_tuser ('0), + + .pause_req('0), + + .m_axis_tdata (m_axis_tdata), + .m_axis_tkeep (m_axis_tkeep), + .m_axis_tvalid (m_axis_tvalid), + .m_axis_tready (m_axis_tready), + .m_axis_tlast (m_axis_tlast), + .m_axis_tid (), + .m_axis_tdest (), + .m_axis_tuser () +); + +endmodule diff --git a/finn-rtllib/dynload/hdl/dynamic_load.sv b/finn-rtllib/dynload/hdl/dynamic_load.sv index 1b41b310a1..1918f4e7db 100644 --- a/finn-rtllib/dynload/hdl/dynamic_load.sv +++ b/finn-rtllib/dynload/hdl/dynamic_load.sv @@ -36,7 +36,8 @@ module dynamic_load #( int unsigned WEIGHT_WIDTH, int unsigned MH, int unsigned MW, - int unsigned N_REPS + int unsigned N_REPS, + parameter RAM_STYLE = "distributed" )( input logic ap_clk, input logic ap_rst_n, @@ -60,8 +61,6 @@ localparam int unsigned N_TLS = SF*NF; localparam int unsigned SIMD_BITS = (SIMD == 1) ? 1 : $clog2(SIMD); localparam int unsigned WGT_ADDR_BITS = (N_TLS == 1) ? 1 : $clog2(N_TLS); -localparam int unsigned RAM_BITS = (WEIGHT_WIDTH + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; localparam int unsigned NF_BITS = (NF == 1) ? 1 : $clog2(NF); localparam int unsigned SF_BITS = (SF == 1) ? 1 : $clog2(SF); localparam int unsigned N_TLS_BITS = (N_TLS == 1) ? 1 : $clog2(N_TLS); @@ -85,9 +84,8 @@ logic[N_TLS_BITS-1:0] curr_sf_C = '0, curr_sf_N; logic[SIMD_BITS-1:0] curr_simd_C = '0, curr_simd_N; // -- Signals -logic [1:0][PE-1:0][SIMD-1:0][WGT_EN_BITS-1:0] a_we; // Bank enables +logic [1:0][SIMD-1:0] a_we; logic [1:0][WGT_ADDR_BITS-1:0] a_addr; -logic [1:0][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; // -- Offsets for(genvar i = 0; i < NF; i++) begin @@ -147,12 +145,8 @@ always_comb begin : DP_PROC_WR // Buffers a_we = '0; - for(int i = 0; i < 2; i++) begin + for(int i = 0; i < 2; i++) a_addr[i] = offsets[curr_nf_C] + curr_sf_C; - for(int j = 0; j < PE; j++) - for(int k = 0; k < SIMD; k++) - a_data_in[i][j][k] = idat[j]; - end // Write and count case (state_wr_C) @@ -160,16 +154,7 @@ always_comb begin : DP_PROC_WR irdy = 1'b1; if(ivld) begin - for(int i = 0; i < PE; i++) begin - for(int j = 0; j < SIMD; j++) begin - if(curr_simd_C == j) begin - if(state_wr_C == ST_WR_0) - a_we[0][i][j] = '1; - else - a_we[1][i][j] = '1; - end - end - end + a_we[state_wr_C == ST_WR_1][curr_simd_C] = 1; curr_nf_N = (curr_nf_C == NF-1) ? 0 : curr_nf_C + 1; curr_simd_N = (curr_nf_C == NF-1) ? ((curr_simd_C == SIMD-1) ? 0 : curr_simd_C + 1) : curr_simd_C; @@ -295,29 +280,23 @@ assign ovld = vld_C; assign odat = odat_C; // ---------------------------------------------------------------------------- -// Matrix +// Weight RAMs // ---------------------------------------------------------------------------- -for(genvar i = 0; i < 2; i++) begin - for(genvar j = 0; j < PE; j++) begin - for(genvar k = 0; k < SIMD; k++) begin - ram_p_c #( - .ADDR_BITS(WGT_ADDR_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("distributed") - ) inst_ram_tp_c ( - .clk(ap_clk), - .a_en(1'b1), - .a_we(a_we[i][j][k]), - .a_addr(a_addr[i]), - .b_en(ordy), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i][j][k]), - .a_data_out(), - .b_data_out(odat_ram[i][j][k]) - ); +for(genvar i = 0; i < 2; i++) begin : genBank + for(genvar k = 0; k < SIMD; k++) begin : genSimd + (* RAM_STYLE = RAM_STYLE *) + logic [PE-1:0][WEIGHT_WIDTH-1:0] Ram[2**WGT_ADDR_BITS]; + logic [PE-1:0][WEIGHT_WIDTH-1:0] RdReg; + + always_ff @(posedge ap_clk) begin + if(a_we[i][k]) Ram[a_addr[i]] <= idat; + if(ordy) begin + RdReg <= Ram[b_addr[i]]; + foreach(RdReg[p]) odat_ram[i][p][k] <= RdReg[p]; + end end - end -end + end : genSimd +end : genBank endmodule : dynamic_load diff --git a/finn-rtllib/fetch_weights/fetch_weights.sv b/finn-rtllib/fetch_weights/fetch_weights.sv new file mode 100644 index 0000000000..bad9304c79 --- /dev/null +++ b/finn-rtllib/fetch_weights/fetch_weights.sv @@ -0,0 +1,328 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module fetch_weights #( + int unsigned PE, + int unsigned SIMD, + int unsigned TH = 1, + int unsigned MH, + int unsigned MW, + int unsigned N_REPS, + int unsigned WEIGHT_WIDTH = 8, + + int unsigned ADDR_BITS = 64, + int unsigned DATA_BITS = 256, + int unsigned LEN_BITS = 32, + int unsigned IDX_BITS = 16, + + int unsigned N_LAYERS, + + int unsigned QDEPTH = 8, + int unsigned EN_OREG = 1, + int unsigned N_DCPL_STGS = 1, + int unsigned DBG = 0, + + // Safely deducible parameters + int unsigned IWSIMD = (TH > 1)? ((PE*SIMD)/TH) : SIMD, + int unsigned OWSIMD = (PE * SIMD) / TH, + int unsigned DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8, + int unsigned WS_BITS_BA = (OWSIMD*WEIGHT_WIDTH+7)/8 * 8, + logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8 + (DATA_BITS/8-1)) & ~(DATA_BITS/8-1) // AXI bus-width aligned +)( + input logic aclk, + input logic aresetn, + + output logic m_done, + + // AXI + output logic[ADDR_BITS-1:0] m_axi_ddr_araddr, + output logic[1:0] m_axi_ddr_arburst, + output logic[3:0] m_axi_ddr_arcache, + output logic[1:0] m_axi_ddr_arid, + output logic[7:0] m_axi_ddr_arlen, + output logic[0:0] m_axi_ddr_arlock, + output logic[2:0] m_axi_ddr_arprot, + output logic[2:0] m_axi_ddr_arsize, + input logic m_axi_ddr_arready, + output logic m_axi_ddr_arvalid, + output logic[ADDR_BITS-1:0] m_axi_ddr_awaddr, + output logic[1:0] m_axi_ddr_awburst, + output logic[3:0] m_axi_ddr_awcache, + output logic[1:0] m_axi_ddr_awid, + output logic[7:0] m_axi_ddr_awlen, + output logic[0:0] m_axi_ddr_awlock, + output logic[2:0] m_axi_ddr_awprot, + output logic[2:0] m_axi_ddr_awsize, + input logic m_axi_ddr_awready, + output logic m_axi_ddr_awvalid, + input logic[DATA_BITS-1:0] m_axi_ddr_rdata, + input logic[1:0] m_axi_ddr_rid, + input logic m_axi_ddr_rlast, + input logic[1:0] m_axi_ddr_rresp, + output logic m_axi_ddr_rready, + input logic m_axi_ddr_rvalid, + output logic[DATA_BITS-1:0] m_axi_ddr_wdata, + output logic m_axi_ddr_wlast, + output logic[DATA_BITS/8-1:0] m_axi_ddr_wstrb, + input logic m_axi_ddr_wready, + output logic m_axi_ddr_wvalid, + input logic[1:0] m_axi_ddr_bid, + input logic[1:0] m_axi_ddr_bresp, + output logic m_axi_ddr_bready, + input logic m_axi_ddr_bvalid, + + // Index + input logic s_idx_tvalid, + output logic s_idx_tready, + input logic[IDX_BITS-1:0] s_idx_tdata, + + // DMA stream out (to external width converter) + output logic axis_dma_tvalid, + input logic axis_dma_tready, + output logic[DATA_BITS-1:0] axis_dma_tdata, + output logic[DATA_BITS/8-1:0] axis_dma_tkeep, + output logic axis_dma_tlast, + + // DWC stream in (from external width converter) + input logic axis_dwc_tvalid, + output logic axis_dwc_tready, + input logic[DS_BITS_BA-1:0] axis_dwc_tdata, + input logic[(DS_BITS_BA)/8-1:0] axis_dwc_tkeep, + input logic axis_dwc_tlast, + + // Stream + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic[WS_BITS_BA-1:0] m_axis_tdata +); + + //=== Layer Offsets ===================================================== + logic [N_LAYERS-1:0][ADDR_BITS-1:0] l_offsets; + for(genvar i = 0; i < N_LAYERS; i++) begin : genOffs + assign l_offsets[i] = i * LAYER_OFFS; + end : genOffs + + //=== Index Handling & DMA Control ====================================== + logic dma_tvalid; + logic dma_tready; + logic [ADDR_BITS-1:0] dma_addr; + logic [ LEN_BITS-1:0] dma_len; + + if(TH > 1) begin : genTiled + + localparam int unsigned REPS_BITS = (N_REPS == 1)? 1 : $clog2(N_REPS); + + typedef enum logic [0:0] {ST_IDLE, ST_DMA} state_e; + + //--- Registers ----------------------------------------------------- + state_e State = ST_IDLE; + state_e state_n; + + logic [REPS_BITS-1:0] CntDma = '0; + logic [REPS_BITS-1:0] cnt_dma_n; + + logic [IDX_BITS-1:0] Idx = '0; + logic [IDX_BITS-1:0] idx_n; + + //--- Index Queue --------------------------------------------------- + uwire q_idx_vld; + logic q_idx_rdy; + uwire [IDX_BITS-1:0] q_idx_dat; + + Q_srl #(.depth(QDEPTH), .width(IDX_BITS)) inst_queue_in ( + .clock(aclk), .reset(!aresetn), + .count(), .maxcount(), + .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), + .o_d(q_idx_dat), .o_v(q_idx_vld), .o_r(q_idx_rdy) + ); + + assign dma_addr = l_offsets[Idx]; + assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; + + //--- Sequential ---------------------------------------------------- + always_ff @(posedge aclk) begin + if(~aresetn) begin + State <= ST_IDLE; + CntDma <= '0; + Idx <= 'x; + end + else begin + State <= state_n; + CntDma <= cnt_dma_n; + Idx <= idx_n; + end + end + + //--- Next State ---------------------------------------------------- + always_comb begin + state_n = State; + + case(State) + ST_IDLE: + state_n = q_idx_vld? ST_DMA : ST_IDLE; + + ST_DMA: + state_n = ((CntDma == N_REPS-1) && dma_tready)? ST_IDLE : ST_DMA; + endcase + end + + //--- Datapath ------------------------------------------------------ + always_comb begin + cnt_dma_n = CntDma; + idx_n = Idx; + + q_idx_rdy = 0; + dma_tvalid = 0; + + case(State) + ST_IDLE: begin + q_idx_rdy = 1; + cnt_dma_n = 0; + if(q_idx_vld) + idx_n = q_idx_dat; + end + + ST_DMA: begin + dma_tvalid = 1; + if(dma_tready) + cnt_dma_n = CntDma + 1; + end + endcase + end + + end : genTiled + else begin : genDirect + + uwire [IDX_BITS-1:0] q_idx_dat; + + Q_srl #(.depth(QDEPTH), .width(IDX_BITS)) inst_idx_queue ( + .clock(aclk), .reset(!aresetn), + .count(), .maxcount(), + .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), + .o_d(q_idx_dat), .o_v(dma_tvalid), .o_r(dma_tready) + ); + + assign dma_addr = l_offsets[q_idx_dat]; + assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; + + end : genDirect + + //=== Write Channel Tie-off (read-only DMA) ============================= + assign m_axi_ddr_awaddr = '0; + assign m_axi_ddr_awburst = '0; + assign m_axi_ddr_awcache = '0; + assign m_axi_ddr_awid = '0; + assign m_axi_ddr_awlen = '0; + assign m_axi_ddr_awlock = '0; + assign m_axi_ddr_awprot = '0; + assign m_axi_ddr_awsize = '0; + assign m_axi_ddr_awvalid = 0; + assign m_axi_ddr_wdata = '0; + assign m_axi_ddr_wlast = 0; + assign m_axi_ddr_wstrb = '0; + assign m_axi_ddr_wvalid = 0; + assign m_axi_ddr_bready = 0; + + //=== DMA Engine ======================================================== + cdma_u_rd #( + .DATA_BITS(DATA_BITS), + .ADDR_BITS(ADDR_BITS), + .LEN_BITS(LEN_BITS) + ) inst_dma ( + .aclk(aclk), .aresetn(aresetn), + + .rd_valid(dma_tvalid), .rd_ready(dma_tready), + .rd_paddr(dma_addr), .rd_len(dma_len), + .rd_done(m_done), + + .m_axi_ddr_arvalid(m_axi_ddr_arvalid), + .m_axi_ddr_arready(m_axi_ddr_arready), + .m_axi_ddr_araddr(m_axi_ddr_araddr), + .m_axi_ddr_arid(m_axi_ddr_arid), + .m_axi_ddr_arlen(m_axi_ddr_arlen), + .m_axi_ddr_arsize(m_axi_ddr_arsize), + .m_axi_ddr_arburst(m_axi_ddr_arburst), + .m_axi_ddr_arlock(m_axi_ddr_arlock), + .m_axi_ddr_arcache(m_axi_ddr_arcache), + .m_axi_ddr_arprot(m_axi_ddr_arprot), + .m_axi_ddr_rvalid(m_axi_ddr_rvalid), + .m_axi_ddr_rready(m_axi_ddr_rready), + .m_axi_ddr_rdata(m_axi_ddr_rdata), + .m_axi_ddr_rlast(m_axi_ddr_rlast), + .m_axi_ddr_rid(m_axi_ddr_rid), + .m_axi_ddr_rresp(m_axi_ddr_rresp), + + .m_axis_ddr_tvalid(axis_dma_tvalid), + .m_axis_ddr_tready(axis_dma_tready), + .m_axis_ddr_tdata(axis_dma_tdata), + .m_axis_ddr_tkeep(axis_dma_tkeep), + .m_axis_ddr_tlast(axis_dma_tlast) + ); + + //=== Local Weight Buffer =============================================== + logic axis_lwb_tvalid; + logic axis_lwb_tready; + logic [WS_BITS_BA-1:0] axis_lwb_tdata; + + if(TH == 1) begin : genLwb + local_weight_buffer #( + .PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), + .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), .DBG(DBG) + ) inst_weight_buff ( + .clk(aclk), .rst(~aresetn), + .ivld(axis_dwc_tvalid), .irdy(axis_dwc_tready), .idat(axis_dwc_tdata), + .ovld(axis_lwb_tvalid), .ordy(axis_lwb_tready), .odat(axis_lwb_tdata) + ); + end : genLwb + else begin : genLwbPassthru + assign axis_lwb_tvalid = axis_dwc_tvalid; + assign axis_dwc_tready = axis_lwb_tready; + assign axis_lwb_tdata = axis_dwc_tdata; + end : genLwbPassthru + + //=== Output Register Slice ============================================= + if(EN_OREG) begin : genOreg + skid #( + .DATA_WIDTH(WS_BITS_BA), .FEED_STAGES(N_DCPL_STGS) + ) inst_oreg ( + .clk(aclk), .rst(!aresetn), + .ivld(axis_lwb_tvalid), .irdy(axis_lwb_tready), .idat(axis_lwb_tdata), + .ovld(m_axis_tvalid), .ordy(m_axis_tready), .odat(m_axis_tdata) + ); + end : genOreg + else begin : genOregPassthru + assign m_axis_tvalid = axis_lwb_tvalid; + assign axis_lwb_tready = m_axis_tready; + assign m_axis_tdata = axis_lwb_tdata; + end : genOregPassthru + +endmodule : fetch_weights diff --git a/finn-rtllib/mlo/fetch_weights_wrapper.v b/finn-rtllib/fetch_weights/fetch_weights_wrapper.v similarity index 79% rename from finn-rtllib/mlo/fetch_weights_wrapper.v rename to finn-rtllib/fetch_weights/fetch_weights_wrapper.v index dc92478b6c..cf79afb7c3 100644 --- a/finn-rtllib/mlo/fetch_weights_wrapper.v +++ b/finn-rtllib/fetch_weights/fetch_weights_wrapper.v @@ -28,13 +28,17 @@ * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * + * @brief Verilog AXI-lite wrapper for MVU & VVU. *****************************************************************************/ +`define $EN_MLO$ + module $MODULE_NAME_AXI_WRAPPER$ #( parameter MW = $MW$, parameter MH = $MH$, parameter PE = $PE$, parameter SIMD = $SIMD$, + parameter TH = $TH$, parameter N_REPS = $N_REPS$, parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$, parameter N_LAYERS = $N_LAYERS$, @@ -45,7 +49,10 @@ module $MODULE_NAME_AXI_WRAPPER$ #( parameter IDX_BITS = 16, // Safely deducible parameters - parameter WS_BITS_BA = (PE*SIMD*WEIGHT_WIDTH+7)/8 * 8 + parameter IWSIMD = $IWSIMD$, + parameter WSIMD = $WSIMD$, + parameter DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8, + parameter WS_BITS_BA = (WSIMD*WEIGHT_WIDTH+7)/8 * 8 )( // Global Control (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF axi_mm:in_idx0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) @@ -95,10 +102,12 @@ module $MODULE_NAME_AXI_WRAPPER$ #( output wire axi_mm_bready, input wire axi_mm_bvalid, +`ifdef EN_MLO // Index input wire in_idx0_V_tvalid, output wire in_idx0_V_tready, input wire[IDX_BITS-1:0] in_idx0_V_tdata, +`endif // Stream output wire out0_V_tvalid, @@ -106,15 +115,46 @@ module $MODULE_NAME_AXI_WRAPPER$ #( output wire[WS_BITS_BA-1:0] out0_V_tdata ); +`ifndef EN_MLO + wire in_idx0_V_tvalid; + wire in_idx0_V_tready; + wire [IDX_BITS-1:0] in_idx0_V_tdata; + + assign in_idx0_V_tvalid = 1'b1; + assign in_idx0_V_tdata = 0; +`endif + +// DMA <-> DWC internal wires +wire axis_dma_tvalid; +wire axis_dma_tready; +wire [DATA_BITS-1:0] axis_dma_tdata; +wire [DATA_BITS/8-1:0] axis_dma_tkeep; +wire axis_dma_tlast; + +wire axis_dwc_tvalid; +wire axis_dwc_tready; +wire [DS_BITS_BA-1:0] axis_dwc_tdata; +wire [(DS_BITS_BA)/8-1:0] axis_dwc_tkeep; +wire axis_dwc_tlast; + +// Width converter +$DWC_MODULE_NAME$ inst_dwc ( + .aclk(ap_clk), .aresetn(ap_rst_n), + .s_axis_tvalid(axis_dma_tvalid), .s_axis_tready(axis_dma_tready), .s_axis_tdata(axis_dma_tdata), .s_axis_tkeep(axis_dma_tkeep), .s_axis_tlast(axis_dma_tlast), + .m_axis_tvalid(axis_dwc_tvalid), .m_axis_tready(axis_dwc_tready), .m_axis_tdata(axis_dwc_tdata), .m_axis_tkeep(axis_dwc_tkeep), .m_axis_tlast(axis_dwc_tlast) +); fetch_weights #( - .PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), .N_REPS(N_REPS), + .PE(PE), .SIMD(SIMD), .TH(TH), + .MH(MH), .MW(MW), .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), + .IWSIMD(IWSIMD), .OWSIMD(WSIMD), .ADDR_BITS(ADDR_BITS), .DATA_BITS(DATA_BITS), .LEN_BITS(LEN_BITS), .IDX_BITS(IDX_BITS), .N_LAYERS(N_LAYERS) ) inst ( .aclk (ap_clk), .aresetn (ap_rst_n), + .m_done (out_done), .m_axi_ddr_araddr (axi_mm_araddr), .m_axi_ddr_arburst (axi_mm_arburst), @@ -156,6 +196,18 @@ fetch_weights #( .s_idx_tready (in_idx0_V_tready), .s_idx_tdata (in_idx0_V_tdata), + .axis_dma_tvalid (axis_dma_tvalid), + .axis_dma_tready (axis_dma_tready), + .axis_dma_tdata (axis_dma_tdata), + .axis_dma_tkeep (axis_dma_tkeep), + .axis_dma_tlast (axis_dma_tlast), + + .axis_dwc_tvalid (axis_dwc_tvalid), + .axis_dwc_tready (axis_dwc_tready), + .axis_dwc_tdata (axis_dwc_tdata), + .axis_dwc_tkeep (axis_dwc_tkeep), + .axis_dwc_tlast (axis_dwc_tlast), + .m_axis_tvalid (out0_V_tvalid), .m_axis_tready (out0_V_tready), .m_axis_tdata (out0_V_tdata) diff --git a/finn-rtllib/fetch_weights/local_weight_buffer.sv b/finn-rtllib/fetch_weights/local_weight_buffer.sv new file mode 100644 index 0000000000..71dbc14024 --- /dev/null +++ b/finn-rtllib/fetch_weights/local_weight_buffer.sv @@ -0,0 +1,273 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module local_weight_buffer #( + int unsigned PE, + int unsigned SIMD, + int unsigned WEIGHT_WIDTH = 8, + int unsigned MH, + int unsigned MW, + int unsigned N_REPS, + int unsigned DBG = 0, + parameter RAM_STYLE = "block" +)( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [SIMD-1:0][WEIGHT_WIDTH-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat +); + + //=== Constants and Types =============================================== + localparam int unsigned SF = MW / SIMD; + localparam int unsigned NF = MH / PE; + localparam int unsigned N_TLS = SF * NF; + localparam int unsigned PE_BITS = (PE == 1)? 1 : $clog2(PE); + localparam int unsigned WGT_ADDR_BITS = $clog2(NF * SF); + localparam int unsigned N_TLS_BITS = $clog2(N_TLS); + localparam int unsigned N_REPS_BITS = $clog2(N_REPS); + + typedef enum logic [1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_e; + typedef enum logic {ST_RD_0, ST_RD_1} state_rd_e; + + //=== Writer ============================================================ + + //--- Registers --------------------------------------------------------- + state_wr_e StateWr = ST_WR_0; + state_wr_e state_wr_n; + state_rd_e StateRd = ST_RD_0; + state_rd_e state_rd_n; + + logic [N_TLS_BITS-1:0] WrPntr = '0; + logic [N_TLS_BITS-1:0] wr_pntr_n; + + logic [PE_BITS-1:0] CurrPe = '0; + logic [PE_BITS-1:0] curr_pe_n; + + //--- Signals ----------------------------------------------------------- + logic [1:0][PE-1:0] a_we; + logic [1:0][WGT_ADDR_BITS-1:0] a_addr; + logic [1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; + + //--- Sequential -------------------------------------------------------- + always_ff @(posedge clk) begin + if(rst) begin + StateWr <= ST_WR_0; + WrPntr <= '0; + CurrPe <= '0; + end + else begin + StateWr <= state_wr_n; + WrPntr <= wr_pntr_n; + CurrPe <= curr_pe_n; + end + end + + //--- Next State -------------------------------------------------------- + always_comb begin + state_wr_n = StateWr; + + case(StateWr) + ST_WR_0: + if((CurrPe == PE-1) && (WrPntr == N_TLS-1) && ivld) + state_wr_n = (StateRd == ST_RD_0)? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_0_WAIT: + state_wr_n = (StateRd == ST_RD_0)? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if((CurrPe == PE-1) && (WrPntr == N_TLS-1) && ivld) + state_wr_n = (StateRd == ST_RD_1)? ST_WR_0 : ST_WR_1_WAIT; + + ST_WR_1_WAIT: + state_wr_n = (StateRd == ST_RD_1)? ST_WR_0 : ST_WR_1_WAIT; + endcase + end + + //--- Datapath ---------------------------------------------------------- + always_comb begin + wr_pntr_n = WrPntr; + curr_pe_n = CurrPe; + + irdy = 0; + + a_we = '0; + for(int i = 0; i < 2; i++) begin + a_addr[i] = WrPntr; + a_data_in[i] = idat; + end + + case(StateWr) + ST_WR_0, ST_WR_1: begin + irdy = 1; + + if(ivld) begin + for(int i = 0; i < PE; i++) + if(CurrPe == i) + a_we[StateWr == ST_WR_1][i] = 1; + + curr_pe_n = (CurrPe == PE-1)? 0 : CurrPe + 1; + wr_pntr_n = (CurrPe == PE-1)? ((WrPntr == N_TLS-1)? 0 : WrPntr + 1) : WrPntr; + end + end + endcase + end + + //=== Reader ============================================================ + + //--- Registers --------------------------------------------------------- + logic [N_TLS_BITS-1:0] RdPntr = '0; + logic [N_TLS_BITS-1:0] rd_pntr_n; + + logic [N_REPS_BITS-1:0] Reps = '0; + logic [N_REPS_BITS-1:0] reps_n; + + logic [1:0] VldS0 = '0; + logic [1:0] vld_s0_n; + + logic [1:0] VldS1 = '0; + logic [1:0] vld_s1_n; + + logic Vld = 0; + logic vld_n; + + logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] Odat = '0; + logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_n; + + //--- Signals ----------------------------------------------------------- + logic [1:0][WGT_ADDR_BITS-1:0] b_addr; + logic [1:0][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_ram; + + //--- Sequential -------------------------------------------------------- + always_ff @(posedge clk) begin + if(rst) begin + StateRd <= ST_RD_0; + RdPntr <= '0; + Reps <= '0; + VldS0 <= '0; + VldS1 <= '0; + Vld <= 0; + Odat <= 'x; + end + else begin + StateRd <= state_rd_n; + RdPntr <= rd_pntr_n; + Reps <= reps_n; + VldS0 <= vld_s0_n; + VldS1 <= vld_s1_n; + Vld <= vld_n; + Odat <= odat_n; + end + end + + //--- Next State -------------------------------------------------------- + always_comb begin + state_rd_n = StateRd; + + case(StateRd) + ST_RD_0: + if(ordy && ((StateWr == ST_WR_0)? (WrPntr > RdPntr) : 1)) + if((RdPntr == N_TLS-1) && (Reps == N_REPS-1)) + state_rd_n = ST_RD_1; + + ST_RD_1: + if(ordy && ((StateWr == ST_WR_1)? (WrPntr > RdPntr) : 1)) + if((RdPntr == N_TLS-1) && (Reps == N_REPS-1)) + state_rd_n = ST_RD_0; + endcase + end + + //--- Datapath ---------------------------------------------------------- + always_comb begin + rd_pntr_n = RdPntr; + reps_n = Reps; + + for(int i = 0; i < 2; i++) begin + vld_s0_n[i] = ordy? 0 : VldS0[i]; + vld_s1_n[i] = ordy? VldS0[i] : VldS1[i]; + end + + vld_n = ordy? |VldS1 : Vld; + odat_n = ordy? (VldS1[0]? odat_ram[0] : odat_ram[1]) : Odat; + + for(int i = 0; i < 2; i++) + b_addr[i] = RdPntr; + + case(StateRd) + ST_RD_0: begin + if(ordy) begin + if((StateWr == ST_WR_0)? (WrPntr > RdPntr) : 1) begin + vld_s0_n[0] = 1; + rd_pntr_n = (RdPntr == N_TLS-1)? 0 : RdPntr + 1; + reps_n = (RdPntr == N_TLS-1)? ((Reps == N_REPS-1)? 0 : Reps + 1) : Reps; + end + end + end + + ST_RD_1: begin + if(ordy) begin + if((StateWr == ST_WR_1)? (WrPntr > RdPntr) : 1) begin + vld_s0_n[1] = 1; + rd_pntr_n = (RdPntr == N_TLS-1)? 0 : RdPntr + 1; + reps_n = (RdPntr == N_TLS-1)? ((Reps == N_REPS-1)? 0 : Reps + 1) : Reps; + end + end + end + endcase + end + + assign ovld = Vld; + assign odat = Odat; + + //=== Weight RAMs ======================================================= + for(genvar i = 0; i < 2; i++) begin : genBank + for(genvar j = 0; j < PE; j++) begin : genPe + (* RAM_STYLE = RAM_STYLE *) + logic [SIMD-1:0][WEIGHT_WIDTH-1:0] Ram[2**WGT_ADDR_BITS]; + logic [SIMD-1:0][WEIGHT_WIDTH-1:0] RdReg; + + always_ff @(posedge clk) begin + if(a_we[i][j]) Ram[a_addr[i]] <= a_data_in[i]; + if(ordy) begin + RdReg <= Ram[b_addr[i]]; + odat_ram[i][j] <= RdReg; + end + end + end : genPe + end : genBank + +endmodule : local_weight_buffer diff --git a/finn-rtllib/mlo/fetch_weights.sv b/finn-rtllib/mlo/fetch_weights.sv deleted file mode 100644 index fda40e45d8..0000000000 --- a/finn-rtllib/mlo/fetch_weights.sv +++ /dev/null @@ -1,226 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module fetch_weights #( - int unsigned PE, - int unsigned SIMD, - int unsigned MH, - int unsigned MW, - int unsigned N_REPS, - int unsigned WEIGHT_WIDTH = 8, - - int unsigned ADDR_BITS = 64, - int unsigned DATA_BITS = 256, - int unsigned LEN_BITS = 32, - int unsigned IDX_BITS = 16, - - int unsigned N_LAYERS, - - int unsigned QDEPTH = 8, - int unsigned EN_OREG = 1, - int unsigned N_DCPL_STGS = 1, - int unsigned DBG = 0, - - // Safely deducible parameters - int unsigned DS_BITS_BA = (SIMD*WEIGHT_WIDTH+7)/8 * 8, - int unsigned WS_BITS_BA = (PE*SIMD*WEIGHT_WIDTH+7)/8 * 8, - logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7 // 8-byte aligned -) ( - input logic aclk, - input logic aresetn, - - output logic m_done, - - // AXI - output logic[ADDR_BITS-1:0] m_axi_ddr_araddr, - output logic[1:0] m_axi_ddr_arburst, - output logic[3:0] m_axi_ddr_arcache, - output logic[1:0] m_axi_ddr_arid, - output logic[7:0] m_axi_ddr_arlen, - output logic[0:0] m_axi_ddr_arlock, - output logic[2:0] m_axi_ddr_arprot, - output logic[2:0] m_axi_ddr_arsize, - input logic m_axi_ddr_arready, - output logic m_axi_ddr_arvalid, - output logic[ADDR_BITS-1:0] m_axi_ddr_awaddr, - output logic[1:0] m_axi_ddr_awburst, - output logic[3:0] m_axi_ddr_awcache, - output logic[1:0] m_axi_ddr_awid, - output logic[7:0] m_axi_ddr_awlen, - output logic[0:0] m_axi_ddr_awlock, - output logic[2:0] m_axi_ddr_awprot, - output logic[2:0] m_axi_ddr_awsize, - input logic m_axi_ddr_awready, - output logic m_axi_ddr_awvalid, - input logic[DATA_BITS-1:0] m_axi_ddr_rdata, - input logic[1:0] m_axi_ddr_rid, - input logic m_axi_ddr_rlast, - input logic[1:0] m_axi_ddr_rresp, - output logic m_axi_ddr_rready, - input logic m_axi_ddr_rvalid, - output logic[DATA_BITS-1:0] m_axi_ddr_wdata, - output logic m_axi_ddr_wlast, - output logic[DATA_BITS/8-1:0] m_axi_ddr_wstrb, - input logic m_axi_ddr_wready, - output logic m_axi_ddr_wvalid, - input logic[1:0] m_axi_ddr_bid, - input logic[1:0] m_axi_ddr_bresp, - output logic m_axi_ddr_bready, - input logic m_axi_ddr_bvalid, - - // Index - input logic s_idx_tvalid, - output logic s_idx_tready, - input logic[IDX_BITS-1:0] s_idx_tdata, - - // Stream - // TODO: Should we reg this? Would be quite wide ... - output logic m_axis_tvalid, - input logic m_axis_tready, - output logic[WS_BITS_BA-1:0] m_axis_tdata -); - -localparam int unsigned WMAT_SIZE = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; - -// Offsets -logic [N_LAYERS-1:0][ADDR_BITS-1:0] l_offsets; -for(genvar i = 0; i < N_LAYERS; i++) begin - assign l_offsets[i] = (i * LAYER_OFFS); -end - -logic q_idx_out_tvalid, q_idx_out_tready; -logic [IDX_BITS-1:0] q_idx_out_tdata; -logic [ADDR_BITS-1:0] q_dma_addr; -logic [LEN_BITS-1:0] q_dma_len; - -// Queues -Q_srl #( - .depth(QDEPTH), - .width(IDX_BITS) -) inst_queue_in ( - .clock(aclk), .reset(!aresetn), - .count(), .maxcount(), - .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), - .o_d(q_idx_out_tdata), .o_v(q_idx_out_tvalid), .o_r(q_idx_out_tready) -); - -assign q_dma_addr = l_offsets[q_idx_out_tdata]; -assign q_dma_len = WMAT_SIZE; - -// DMA -logic axis_dma_tvalid; -logic axis_dma_tready; -logic[DATA_BITS-1:0] axis_dma_tdata; -logic[DATA_BITS/8-1:0] axis_dma_tkeep; -logic axis_dma_tlast; - -cdma_u_rd #( - .DATA_BITS(DATA_BITS), - .ADDR_BITS(ADDR_BITS), - .LEN_BITS(LEN_BITS) -) inst_dma ( - .aclk(aclk), .aresetn(aresetn), - - .rd_valid(q_idx_out_tvalid), .rd_ready(q_idx_out_tready), - .rd_paddr(q_dma_addr), .rd_len(q_dma_len), - .rd_done(m_done), - - .m_axi_ddr_arvalid(m_axi_ddr_arvalid), - .m_axi_ddr_arready(m_axi_ddr_arready), - .m_axi_ddr_araddr(m_axi_ddr_araddr), - .m_axi_ddr_arid(m_axi_ddr_arid), - .m_axi_ddr_arlen(m_axi_ddr_arlen), - .m_axi_ddr_arsize(m_axi_ddr_arsize), - .m_axi_ddr_arburst(m_axi_ddr_arburst), - .m_axi_ddr_arlock(m_axi_ddr_arlock), - .m_axi_ddr_arcache(m_axi_ddr_arcache), - .m_axi_ddr_arprot(m_axi_ddr_arprot), - .m_axi_ddr_rvalid(m_axi_ddr_rvalid), - .m_axi_ddr_rready(m_axi_ddr_rready), - .m_axi_ddr_rdata(m_axi_ddr_rdata), - .m_axi_ddr_rlast(m_axi_ddr_rlast), - .m_axi_ddr_rid(m_axi_ddr_rid), - .m_axi_ddr_rresp(m_axi_ddr_rresp), - - .m_axis_ddr_tvalid(axis_dma_tvalid), - .m_axis_ddr_tready(axis_dma_tready), - .m_axis_ddr_tdata(axis_dma_tdata), - .m_axis_ddr_tkeep(axis_dma_tkeep), - .m_axis_ddr_tlast(axis_dma_tlast) -); - -// Width conversion -logic axis_dwc_tvalid; -logic axis_dwc_tready; -logic[DS_BITS_BA-1:0] axis_dwc_tdata; -logic[(DS_BITS_BA)/8-1:0] axis_dwc_tkeep; -logic axis_dwc_tlast; - -axis_fifo_adapter #( - .S_DATA_WIDTH(DATA_BITS), .M_DATA_WIDTH(DS_BITS_BA) -) inst_dwc ( - .clk(aclk), .rst(~aresetn), - .pause_req('0), .s_axis_tid('0), .s_axis_tdest('0), .s_axis_tuser('0), - .s_axis_tvalid(axis_dma_tvalid), .s_axis_tready(axis_dma_tready), .s_axis_tdata(axis_dma_tdata), .s_axis_tkeep(axis_dma_tkeep), .s_axis_tlast(axis_dma_tlast), - .pause_ack(), .m_axis_tid(), .m_axis_tdest(), .m_axis_tuser(), - .m_axis_tvalid(axis_dwc_tvalid), .m_axis_tready(axis_dwc_tready), .m_axis_tdata(axis_dwc_tdata), .m_axis_tkeep(axis_dwc_tkeep), .m_axis_tlast(axis_dwc_tlast) -); - -// Double buffer -logic axis_lwb_tvalid; -logic axis_lwb_tready; -logic[WS_BITS_BA-1:0] axis_lwb_tdata; - -local_weight_buffer #( - .PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), .DBG(DBG) -) inst_weight_buff ( - .clk(aclk), .rst(~aresetn), - .ivld(axis_dwc_tvalid), .irdy(axis_dwc_tready), .idat(axis_dwc_tdata), - .ovld(axis_lwb_tvalid), .ordy(axis_lwb_tready), .odat(axis_lwb_tdata) -); - -// Reg slice -if(EN_OREG) begin - skid #( - .DATA_WIDTH(WS_BITS_BA), .FEED_STAGES(N_DCPL_STGS) - ) inst_oreg ( - .clk(aclk), .rst(~aresetn), - .ivld(axis_lwb_tvalid), .irdy(axis_lwb_tready), .idat(axis_lwb_tdata), - .ovld(m_axis_tvalid), .ordy(m_axis_tready), .odat(m_axis_tdata) - ); -end else begin - assign m_axis_tvalid = axis_lwb_tvalid; - assign axis_lwb_tready = m_axis_tready; - assign m_axis_tdata = axis_lwb_tdata; -end - -endmodule diff --git a/finn-rtllib/mlo/infrastructure/intermediate_frames.sv b/finn-rtllib/mlo/infrastructure/intermediate_frames.sv index 51e426ef2b..c924007955 100644 --- a/finn-rtllib/mlo/infrastructure/intermediate_frames.sv +++ b/finn-rtllib/mlo/infrastructure/intermediate_frames.sv @@ -410,42 +410,18 @@ end logic last_dwc_in; assign last_dwc_in = (cnt_dwc_C == FM_BEATS_IN-1); -axis_fifo_adapter #(.S_DATA_WIDTH(OLEN_BITS), .M_DATA_WIDTH(DATA_BITS)) inst_dwc_wr ( - .clk(aclk), - .rst(~aresetn), - - .pause_req('0), .s_axis_tid('0), .s_axis_tdest('0), .s_axis_tuser('0), - .s_axis_tvalid(s_axis_int_tvalid), - .s_axis_tready(s_axis_int_tready), - .s_axis_tdata (s_axis_int_tdata), - .s_axis_tkeep ('1), - .s_axis_tlast (last_dwc_in), - - .pause_ack(), .m_axis_tid(), .m_axis_tdest(), .m_axis_tuser(), - .m_axis_tvalid(axis_dma_wr_tvalid), - .m_axis_tready(axis_dma_wr_tready), - .m_axis_tdata (axis_dma_wr_tdata), - .m_axis_tkeep (axis_dma_wr_tkeep), - .m_axis_tlast (axis_dma_wr_tlast) +// DWC write: OLEN_BITS -> DATA_BITS (body output -> DMA) +if_dwc_sink inst_dwc_wr ( + .aclk(aclk), .aresetn(aresetn), + .s_axis_tvalid(s_axis_int_tvalid), .s_axis_tready(s_axis_int_tready), .s_axis_tdata(s_axis_int_tdata), .s_axis_tkeep({(OLEN_BITS/8){1'b1}}), .s_axis_tlast(last_dwc_in), + .m_axis_tvalid(axis_dma_wr_tvalid), .m_axis_tready(axis_dma_wr_tready), .m_axis_tdata(axis_dma_wr_tdata), .m_axis_tkeep(axis_dma_wr_tkeep), .m_axis_tlast(axis_dma_wr_tlast) ); -axis_fifo_adapter #(.S_DATA_WIDTH(DATA_BITS), .M_DATA_WIDTH(ILEN_BITS)) inst_dwc_rd ( - .clk(aclk), - .rst(~aresetn), - - .pause_req('0), .s_axis_tid('0), .s_axis_tdest('0), .s_axis_tuser('0), - .s_axis_tvalid(axis_dma_rd_tvalid), - .s_axis_tready(axis_dma_rd_tready), - .s_axis_tdata (axis_dma_rd_tdata), - .s_axis_tkeep (axis_dma_rd_tkeep), - .s_axis_tlast (axis_dma_rd_tlast), - - .pause_ack(), .m_axis_tid(), .m_axis_tdest(), .m_axis_tuser(), - .m_axis_tvalid(m_axis_int_tvalid), - .m_axis_tready(m_axis_int_tready), - .m_axis_tdata (m_axis_int_tdata), - .m_axis_tkeep (), - .m_axis_tlast () +// DWC read: DATA_BITS -> ILEN_BITS (DMA -> body input) +if_dwc_source inst_dwc_rd ( + .aclk(aclk), .aresetn(aresetn), + .s_axis_tvalid(axis_dma_rd_tvalid), .s_axis_tready(axis_dma_rd_tready), .s_axis_tdata(axis_dma_rd_tdata), .s_axis_tkeep(axis_dma_rd_tkeep), .s_axis_tlast(axis_dma_rd_tlast), + .m_axis_tvalid(m_axis_int_tvalid), .m_axis_tready(m_axis_int_tready), .m_axis_tdata(m_axis_int_tdata), .m_axis_tkeep(), .m_axis_tlast() ); // REG diff --git a/finn-rtllib/mlo/local_weight_buffer.sv b/finn-rtllib/mlo/local_weight_buffer.sv deleted file mode 100644 index cdc2a9eca3..0000000000 --- a/finn-rtllib/mlo/local_weight_buffer.sv +++ /dev/null @@ -1,305 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module local_weight_buffer #( - int unsigned PE, - int unsigned SIMD, - int unsigned WEIGHT_WIDTH, - int unsigned MH, - int unsigned MW, - int unsigned N_REPS, - int unsigned DBG = 0 -) ( - input logic clk, - input logic rst, - - input logic ivld, - output logic irdy, - input logic [SIMD-1:0][WEIGHT_WIDTH-1:0] idat, - - output logic ovld, - input logic ordy, - output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat -); - -// ---------------------------------------------------------------------------- -// Consts and types -// ---------------------------------------------------------------------------- - -localparam int unsigned SF = MW/SIMD; -localparam int unsigned NF = MH/PE; -localparam int unsigned N_TLS = SF * NF; - -localparam int unsigned SIMD_BITS = (SIMD == 1) ? 1 : $clog2(SIMD); -localparam int unsigned PE_BITS = (PE == 1) ? 1 : $clog2(PE); -localparam int unsigned WGT_ADDR_BITS = $clog2(NF * SF); -localparam int unsigned RAM_BITS = (SIMD*WEIGHT_WIDTH + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; -localparam int unsigned N_TLS_BITS = $clog2(N_TLS); -localparam int unsigned N_REPS_BITS = $clog2(N_REPS); - -typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; -typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; - -// ---------------------------------------------------------------------------- -// Writer -// ---------------------------------------------------------------------------- - -// -- Regs -state_wr_t state_wr_C = ST_WR_0, state_wr_N; -state_rd_t state_rd_C = ST_RD_0, state_rd_N; - -logic[N_TLS_BITS-1:0] wr_pntr_C = '0, wr_pntr_N; -logic[PE_BITS-1:0] curr_pe_C = '0, curr_pe_N; - -// -- Signals -logic [1:0][PE-1:0][WGT_EN_BITS-1:0] a_we; // Bank enables -logic [1:0][WGT_ADDR_BITS-1:0] a_addr; -logic [1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_WR - if(rst) begin - state_wr_C <= ST_WR_0; - - wr_pntr_C <= '0; - curr_pe_C <= '0; - end - else begin - state_wr_C <= state_wr_N; - - wr_pntr_C <= wr_pntr_N; - curr_pe_C <= curr_pe_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_WR - state_wr_N = state_wr_C; - - case (state_wr_C) - ST_WR_0: - if((curr_pe_C == PE - 1) && (wr_pntr_C == N_TLS - 1) && ivld) begin - state_wr_N = (state_rd_C == ST_RD_0) ? ST_WR_1 : ST_WR_0_WAIT; - end - - ST_WR_0_WAIT: - state_wr_N = (state_rd_C == ST_RD_0) ? ST_WR_1 : ST_WR_0_WAIT; - - ST_WR_1: - if((curr_pe_C == PE - 1) && (wr_pntr_C == N_TLS - 1) && ivld) begin - state_wr_N = (state_rd_C == ST_RD_1) ? ST_WR_0 : ST_WR_1_WAIT; - end - - ST_WR_1_WAIT: - state_wr_N = (state_rd_C == ST_RD_1) ? ST_WR_0 : ST_WR_1_WAIT; - - endcase -end - -// -- DP -always_comb begin : DP_PROC_WR - wr_pntr_N = wr_pntr_C; - curr_pe_N = curr_pe_C; - - // Input - irdy = 1'b0; - - // Buffers - a_we = '0; - for(int i = 0; i < 2; i++) begin - a_addr[i] = wr_pntr_C; - a_data_in[i] = idat; - end - - // Write and count - case (state_wr_C) - ST_WR_0, ST_WR_1: begin - irdy = 1'b1; - - if(ivld) begin - for(int i = 0; i < PE; i++) begin - if(curr_pe_C == i) begin - a_we[state_wr_C == ST_WR_1][i] = '1; - end - end - - curr_pe_N = (curr_pe_C == PE-1) ? 0 : curr_pe_C + 1; - wr_pntr_N = (curr_pe_C == PE-1) ? ((wr_pntr_C == N_TLS-1) ? 0 : wr_pntr_C + 1) : wr_pntr_C; - end - end - endcase - -end - -// ---------------------------------------------------------------------------- -// Reader -// ---------------------------------------------------------------------------- - -// -- Regs -logic [N_TLS_BITS-1:0] rd_pntr_C = '0, rd_pntr_N; -logic [N_REPS_BITS-1:0] reps_C = '0, reps_N; - -//logic [15:0] rd_pntr_C = '0, rd_pntr_N; -//logic [15:0] reps_C = '0, reps_N; - -logic [1:0] vld_s0_C = '0, vld_s0_N; -logic [1:0] vld_s1_C = '0, vld_s1_N; - -logic vld_C = '0, vld_N; -logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_C = '0, odat_N; - -// -- Signals -logic [1:0][WGT_ADDR_BITS-1:0] b_addr; -logic [1:0][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_ram; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_RD - if(rst) begin - state_rd_C <= ST_RD_0; - - rd_pntr_C <= '0; - reps_C <= '0; - - vld_s0_C <= '0; - vld_s1_C <= '0; - vld_C <= '0; - odat_C <= 'X; - end - else begin - state_rd_C <= state_rd_N; - - rd_pntr_C <= rd_pntr_N; - reps_C <= reps_N; - - vld_s0_C <= vld_s0_N; - vld_s1_C <= vld_s1_N; - vld_C <= vld_N; - odat_C <= odat_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_RD - state_rd_N = state_rd_C; - - case (state_rd_C) - ST_RD_0: - if(ordy && ((state_wr_C == ST_WR_0) ? (wr_pntr_C > rd_pntr_C) : 1'b1)) begin - if((rd_pntr_C == N_TLS-1) && (reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_1; - end - end - - ST_RD_1: - if(ordy && ((state_wr_C == ST_WR_1) ? (wr_pntr_C > rd_pntr_C) : 1'b1)) begin - if((rd_pntr_C == N_TLS-1) && (reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_0; - end - end - endcase -end - -// -- DP -always_comb begin : DP_PROC_RD - rd_pntr_N = rd_pntr_C; - reps_N = reps_C; - - for(int i = 0; i < 2; i++) begin - vld_s0_N[i] = ordy ? 1'b0 : vld_s0_C[i]; - vld_s1_N[i] = ordy ? vld_s0_C[i] : vld_s1_C[i]; - end - - vld_N = ordy ? |vld_s1_C : vld_C; - odat_N = ordy ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; - - for(int i = 0; i < 2; i++) begin - b_addr[i] = rd_pntr_C; - end - - case(state_rd_C) - ST_RD_0: begin - if(ordy) begin - if((state_wr_C == ST_WR_0) ? (wr_pntr_C > rd_pntr_C) : 1'b1) begin - - vld_s0_N[0] = 1'b1; - - rd_pntr_N = (rd_pntr_C == N_TLS-1) ? 0 : rd_pntr_C + 1; - reps_N = (rd_pntr_C == N_TLS-1) ? ((reps_C == N_REPS-1) ? 0 : reps_C + 1) : reps_C; - end - end - end - - ST_RD_1: begin - if(ordy) begin - if((state_wr_C == ST_WR_1) ? (wr_pntr_C > rd_pntr_C) : 1'b1) begin - - vld_s0_N[1] = 1'b1; - - rd_pntr_N = (rd_pntr_C == N_TLS-1) ? 0 : rd_pntr_C + 1; - reps_N = (rd_pntr_C == N_TLS-1) ? ((reps_C == N_REPS-1) ? 0 : reps_C + 1) : reps_C; - end - end - end - - endcase - -end - -assign ovld = vld_C; -assign odat = odat_C; - -// ---------------------------------------------------------------------------- -// Weights -// ---------------------------------------------------------------------------- - -for(genvar i = 0; i < 2; i++) begin - for(genvar j = 0; j < PE; j++) begin - ram_p_c #( - .ADDR_BITS(WGT_ADDR_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("block") - ) inst_ram_tp_c ( - .clk(clk), - .a_en(1'b1), - .a_we(a_we[i][j]), - .a_addr(a_addr[i]), - .b_en(ordy), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i]), - .a_data_out(), - .b_data_out(odat_ram[i][j]) - ); - end -end - -endmodule diff --git a/finn-rtllib/mvu_tiled/acc_stage.sv b/finn-rtllib/mvu_tiled/acc_stage.sv new file mode 100644 index 0000000000..7ab3704492 --- /dev/null +++ b/finn-rtllib/mvu_tiled/acc_stage.sv @@ -0,0 +1,144 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + + +module acc_stage #( + int unsigned CHAINLEN, + int unsigned PE, + int unsigned ACCU_WIDTH, + int unsigned TH, + int unsigned TH_MAX = 2*TH +)( + input logic clk, + input logic rst, + input logic en, + + input logic [PE-1:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] idat, + input logic ival, + input logic ilast, + + output logic [PE-1:0][ACCU_WIDTH-1:0] odat, + output logic oval +); + + //=== Adder Tree + Accumulator Add ====================================== + localparam int unsigned TREE_DEPTH = $clog2(CHAINLEN); + localparam int unsigned ADD_LAT = TREE_DEPTH + 1; + + logic [PE-1:0][ACCU_WIDTH-1:0] Acc; + logic [PE-1:0][ACCU_WIDTH-1:0] DatInt; + + for(genvar i = 0; i < PE; i++) begin : genAdd + // Tree reduction of CHAINLEN DSP partial products + logic [ACCU_WIDTH-1:0] add_arg[CHAINLEN]; + for(genvar k = 0; k < CHAINLEN; k++) + assign add_arg[k] = idat[i][k]; + + localparam int unsigned SUM_WIDTH = $clog2(CHAINLEN) + ACCU_WIDTH; + uwire [SUM_WIDTH-1:0] tree_sum; + add_multi #(.N(CHAINLEN), .DEPTH(TREE_DEPTH), .ARG_WIDTH(ACCU_WIDTH)) inst_add ( + .clk(clk), .rst(rst), .en(en), + .arg(add_arg), + .sum(tree_sum) + ); + + // Accumulator add (1 registered stage) + always_ff @(posedge clk) begin + if(rst) DatInt[i] <= 'x; + else if(en) DatInt[i] <= tree_sum[ACCU_WIDTH-1:0] + Acc[i]; + end + end : genAdd + + //=== Valid/Last Pipeline =============================================== + logic [ADD_LAT:0] Val; + logic [ADD_LAT:0] Last; + + assign Val[0] = ival; + assign Last[0] = ilast; + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 1; i <= ADD_LAT; i++) begin + Val [i] <= 0; + Last[i] <= 'x; + end + end + else if(en) begin + for(int i = 1; i <= ADD_LAT; i++) begin + Val [i] <= Val [i-1]; + Last[i] <= Last[i-1]; + end + end + end + + uwire val_out = Val[ADD_LAT]; + uwire last_out = Last[ADD_LAT]; + uwire inc_acc = Val[ADD_LAT-1]; + + //=== Prep Counter ====================================================== + logic signed [$clog2(TH):0] CntPrep = -TH; + uwire prep = CntPrep[$left(CntPrep)]; + always_ff @(posedge clk) begin + if(rst) CntPrep <= -TH; + else CntPrep <= CntPrep + prep; + end + + //=== Accumulation FIFO ================================================= + Q_srl #( + .depth(TH_MAX), + .width(PE*ACCU_WIDTH) + ) inst_acc ( + .clock(clk), + .reset(rst), + .i_d(prep? {PE*ACCU_WIDTH{1'b0}} : (last_out? {PE*ACCU_WIDTH{1'b0}} : DatInt)), + .i_v(prep? 1 : (en? val_out : 0)), + .i_r(), + .o_d(Acc), + .o_v(), + .o_r(en & inc_acc), + .count(), + .maxcount() + ); + + //=== Output Stage ====================================================== + always_ff @(posedge clk) begin + if(rst) begin + odat <= 'x; + oval <= 0; + end + else if(en) begin + odat <= DatInt; + oval <= val_out && last_out; + end + end + +endmodule : acc_stage diff --git a/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv b/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv new file mode 100644 index 0000000000..c42cd63643 --- /dev/null +++ b/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv @@ -0,0 +1,284 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module cu_mvau_tiled #( + int unsigned PE, + int unsigned SIMD, + int unsigned TH, + int unsigned WEIGHT_WIDTH, + int unsigned ACTIVATION_WIDTH, + int unsigned ACCU_WIDTH, + + bit SIGNED_ACTIVATIONS = 1, + localparam int unsigned WEIGHT_ELEMENTS = PE*SIMD +)( + input logic clk, + input logic rst, + input logic en, + + input logic ilast, + input logic ivld, + input logic [WEIGHT_ELEMENTS-1:0][WEIGHT_WIDTH-1:0] w, + input logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] a, + + output logic ovld, + output logic [PE-1:0][ACCU_WIDTH-1:0] p +); + + //=== Startup Recovery Watchdog ========================================= + // The DSP slice needs 100ns of recovery time after initial startup before + // being able to ingest input properly. This watchdog discovers violating + // stimuli during simulation and produces a corresponding warning. + if(1) begin : blkRecoveryWatch + logic Dirty = 1; + initial begin + #100ns; + Dirty <= 0; + end + + always_ff @(posedge clk) begin + assert(!Dirty || rst || !en) else begin + $warning("%m: Feeding input during DSP startup recovery. Expect functional errors."); + end + end + end : blkRecoveryWatch + + //=== Input Formatting ================================================== + localparam int unsigned CHAINLEN = (SIMD+2)/3; + uwire [26:0] a_in_i[CHAINLEN]; + uwire [23:0] b_in_i[PE][CHAINLEN]; + // Array with packed dimension > 256 cannot be handled out-of-the-box with PyVerilator + uwire [PE-1:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] pout; + + //--- Valid/Last Pipeline ----------------------------------------------- + localparam int unsigned DSP_PIPELINE_STAGES = 1; + logic L[0:1+DSP_PIPELINE_STAGES] = '{default: 0}; + logic V[0:1+DSP_PIPELINE_STAGES] = '{default: 0}; + + always_ff @(posedge clk) begin + if(rst) begin + L <= '{default: 0}; + V <= '{default: 0}; + end + else if(en) begin + L[1+DSP_PIPELINE_STAGES] <= ilast; + L[0:DSP_PIPELINE_STAGES] <= L[1:1+DSP_PIPELINE_STAGES]; + + V[1+DSP_PIPELINE_STAGES] <= ivld; + V[0:DSP_PIPELINE_STAGES] <= V[1:1+DSP_PIPELINE_STAGES]; + end + end + + uwire last = L[0]; + uwire vld = V[0]; + + //--- Activation Padding ------------------------------------------------ + localparam int unsigned PAD_BITS_ACT = 9 - ACTIVATION_WIDTH; + for(genvar i = 0; i < CHAINLEN; i++) begin : genActSIMD + localparam int unsigned LANES_OCCUPIED = i == CHAINLEN-1? SIMD - 3*i : 3; + + for(genvar j = 0; j < LANES_OCCUPIED; j++) begin : genAin + assign a_in_i[i][9*j +: 9] = + SIGNED_ACTIVATIONS? + PAD_BITS_ACT == 0? a[3*i+j] : { {PAD_BITS_ACT{a[3*i+j][ACTIVATION_WIDTH-1]}}, a[3*i+j] } : + PAD_BITS_ACT == 0? a[3*i+j] : { {PAD_BITS_ACT{1'b0}}, a[3*i+j] }; + end : genAin + for(genvar j = LANES_OCCUPIED; j < 3; j++) begin : genAinZero + assign a_in_i[i][9*j +: 9] = 9'd0; + end : genAinZero + end : genActSIMD + + //--- Weight Padding ---------------------------------------------------- + localparam int unsigned PAD_BITS_WEIGHT = 8 - WEIGHT_WIDTH; + + for(genvar i = 0; i < PE; i++) begin : genWeightPE + for(genvar j = 0; j < CHAINLEN; j++) begin : genWeightSIMD + localparam int unsigned LANES_OCCUPIED = j == CHAINLEN-1? SIMD - 3*j : 3; + + for(genvar k = 0; k < LANES_OCCUPIED; k++) begin : genBin + assign b_in_i[i][j][8*k +: 8] = + PAD_BITS_WEIGHT == 0? w[SIMD*i+3*j+k] : { {PAD_BITS_WEIGHT{w[SIMD*i+3*j+k][WEIGHT_WIDTH-1]}}, w[SIMD*i+3*j+k] }; + end : genBin + for(genvar k = LANES_OCCUPIED; k < 3; k++) begin : genBinZero + assign b_in_i[i][j][8*k +: 8] = 8'd0; + end : genBinZero + end : genWeightSIMD + end : genWeightPE + + //=== DSP Instances ===================================================== + for(genvar i = 0; i < PE; i++) begin : genPE + for(genvar j = 0; j < CHAINLEN; j++) begin : genChain + localparam int unsigned INTERNAL_REGS = 1; + localparam bit PREG = 1; + + DSP58 #( + // Feature Control Attributes: Data Path Selection + .AMULTSEL("A"), + .A_INPUT("DIRECT"), + .BMULTSEL("B"), + .B_INPUT("DIRECT"), + .DSP_MODE("INT8"), + .PREADDINSEL("A"), + .RND(58'h000000000000000), + .USE_MULT("MULTIPLY"), + .USE_SIMD("ONE58"), + .USE_WIDEXOR("FALSE"), + .XORSIMD("XOR24_34_58_116"), + // Pattern Detector Attributes + .AUTORESET_PATDET("NO_RESET"), + .AUTORESET_PRIORITY("RESET"), + .MASK(58'h0ffffffffffffff), + .PATTERN(58'h000000000000000), + .SEL_MASK("MASK"), + .SEL_PATTERN("PATTERN"), + .USE_PATTERN_DETECT("NO_PATDET"), + // Programmable Inversion Attributes + .IS_ALUMODE_INVERTED(4'b0000), + .IS_CARRYIN_INVERTED(1'b0), + .IS_CLK_INVERTED(1'b0), + .IS_INMODE_INVERTED(5'b00000), + .IS_NEGATE_INVERTED(3'b000), + .IS_OPMODE_INVERTED({ + 2'b00, // W: 0 (unused, accumulation is external) + 3'b000, // Z: 0 (unused) + 2'b01, // Y: M (multiply) + 2'b01 // X: M (multiply) + }), // Static OPMODE='0 inverted to select P = M (multiply-only) + .IS_RSTALLCARRYIN_INVERTED(1'b0), + .IS_RSTALUMODE_INVERTED(1'b0), + .IS_RSTA_INVERTED(1'b0), + .IS_RSTB_INVERTED(1'b0), + .IS_RSTCTRL_INVERTED(1'b0), + .IS_RSTC_INVERTED(1'b0), + .IS_RSTD_INVERTED(1'b0), + .IS_RSTINMODE_INVERTED(1'b0), + .IS_RSTM_INVERTED(1'b0), + .IS_RSTP_INVERTED(1'b0), + // Register Control Attributes + .ACASCREG(INTERNAL_REGS), + .ADREG(0), + .ALUMODEREG(0), + .AREG(INTERNAL_REGS), + .BCASCREG(INTERNAL_REGS), + .BREG(INTERNAL_REGS), + .CARRYINREG(0), + .CARRYINSELREG(0), + .CREG(0), + .DREG(0), + .INMODEREG(1), + .MREG(1), + .OPMODEREG(0), // No register needed: OPMODE is static + .PREG(PREG), + .RESET_MODE("SYNC") + ) + DSP58_inst ( + // Cascade outputs + .ACOUT(), + .BCOUT(), + .CARRYCASCOUT(), + .MULTSIGNOUT(), + .PCOUT(), + // Control outputs + .OVERFLOW(), + .PATTERNBDETECT(), + .PATTERNDETECT(), + .UNDERFLOW(), + // Data outputs + .CARRYOUT(), + .P(pout[i][j]), + .XOROUT(), + // Cascade inputs + .ACIN('x), + .BCIN('x), + .CARRYCASCIN('x), + .MULTSIGNIN('x), + .PCIN('x), + // Control inputs + .ALUMODE(4'h0), + .CARRYINSEL('0), + .CLK(clk), + .INMODE({ + INTERNAL_REGS == 2? 1'b0 : 1'b1, + 2'b00, + 1'b0, + INTERNAL_REGS == 2? 1'b0 : 1'b1 + }), + .NEGATE('0), + .OPMODE('0), // Static (inverted to X=Y=M, W=Z=0) + // Data inputs + .A({ 7'bx, a_in_i[j] }), + .B(b_in_i[i][j]), + .C('x), + .CARRYIN('0), + .D('x), + // Reset/Clock Enable inputs + .ASYNC_RST('0), + .CEA1(en), + .CEA2(INTERNAL_REGS == 2? en : '0), + .CEAD('0), + .CEALUMODE('0), + .CEB1(en), + .CEB2(INTERNAL_REGS == 2? en : '0), + .CEC('0), + .CECARRYIN('0), + .CECTRL('0), + .CED('0), + .CEINMODE(en), + .CEM(en), + .CEP(PREG && en), + .RSTA('0), + .RSTALLCARRYIN('0), + .RSTALUMODE('0), + .RSTB('0), + .RSTC('0), + .RSTCTRL('0), + .RSTD('0), + .RSTINMODE(rst), + .RSTM('0), + .RSTP('0) + ); + end : genChain + end : genPE + + //=== Accumulation ====================================================== + acc_stage #(.CHAINLEN(CHAINLEN), .PE(PE), .ACCU_WIDTH(ACCU_WIDTH), .TH(TH)) inst_acc_stage ( + .clk(clk), + .rst(rst), + .en(en), + .idat(pout), + .ival(vld), + .ilast(last), + .odat(p), + .oval(ovld) + ); + +endmodule : cu_mvau_tiled diff --git a/finn-rtllib/mvu_tiled/input_gen.sv b/finn-rtllib/mvu_tiled/input_gen.sv new file mode 100644 index 0000000000..16e2a443fd --- /dev/null +++ b/finn-rtllib/mvu_tiled/input_gen.sv @@ -0,0 +1,265 @@ +/**************************************************************************** + * Copyright Advanced Micro Devices, Inc. + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Thomas B. Preußer + * @brief + * Generic sliding window / input generator driven by a perfect loop nest. + * + * A loop nest: + * + * for(i0 = 0; i0 < DIMS[0]; i0++) + * for(i1 = 0; i1 < DIMS[1]; i1++) + * ... + * for(in = 0; in < DIMS[D-1]; in++) + * emit(buf[COEFS[0]*i0 + COEFS[1]*i1 + ... + COEFS[D-1]*in]) + * + * is encoded by the array parameters DIMS and COEFS. The module reads + * a linear input stream into a circular buffer and replays elements + * according to the loop nest addressing, supporting arbitrary strides, + * dilations, and transpositions. + * + * FM_SIZE is the number of input elements per feature map (period of the + * input stream). The olst output exposes the level-completion cascade + * term[D-1:0] synchronous with each output beat. + ***************************************************************************/ + +module input_gen #( + int unsigned DATA_WIDTH, + int unsigned FM_SIZE, + int unsigned D, + int unsigned DIMS[D], + int unsigned COEFS[D] +)( + input logic clk, + input logic rst, + + // Input Stream + input logic [DATA_WIDTH-1:0] idat, + input logic ivld, + output logic irdy, + + // Output Stream + output logic [DATA_WIDTH-1:0] odat, + output logic ovld, + output logic [D-1:0] olst, + output logic [D-1:0] odone, + input logic ordy +); + + //=== Parameter Validation ============================================== + initial begin + if(D == 0) begin + $error("%m: D must be at least 1."); + $finish; + end + for(int unsigned i = 0; i < D; i++) begin + if(DIMS[i] == 0) begin + $error("%m: DIMS[%0d] must be positive.", i); + $finish; + end + end + end + + //=== Elaboration-Time Nest Computations ================================ + // Parent coefficient per level (W in the HLS Nest<> encoding): + // W[0] = FM_SIZE, W[i>0] = COEFS[i-1]. + typedef int unsigned w_arr_t[D+1]; + function automatic w_arr_t INIT_W(); + automatic w_arr_t a; + a[0] = FM_SIZE; + for(int unsigned i = 0; i < D; i++) a[i+1] = COEFS[i]; + return a; + endfunction : INIT_W + localparam w_arr_t W = INIT_W(); + + // Free-pointer responsibility flag per level. + // R_FLAG[i] is the R flag passed into level i from its parent. + typedef bit r_flag_arr_t[D+1]; + function automatic r_flag_arr_t INIT_R_FLAG(); + automatic r_flag_arr_t a; + a[0] = 1; + for(int unsigned i = 1; i <= D; i++) + a[i] = a[i-1] && (COEFS[i-1] > 0) + && (COEFS[i-1] * DIMS[i-1] <= W[i-1]); + return a; + endfunction : INIT_R_FLAG + localparam r_flag_arr_t R_FLAG = INIT_R_FLAG(); + + // Terminal read-pointer increment when level i completes. + // Index D covers the default innermost-advance case. + typedef int rp_inc_arr_t[D+1]; + function automatic rp_inc_arr_t INIT_RP_INC(); + automatic rp_inc_arr_t a; + automatic int unsigned rw = 0; // cumulative rp_rewind, built inside out + for(int i = D; i >= 0; i--) begin + if(i < int'(D)) rw = (DIMS[i]-1) * COEFS[i] + rw; + a[i] = int'(W[i]) - int'(rw); + end + return a; + endfunction : INIT_RP_INC + localparam rp_inc_arr_t TERMINAL_RP_INC = INIT_RP_INC(); + + // Negated terminal free-pointer increment when level i completes. + // Stored negated for direct use in the negated capacity counter. + // Index D covers the default innermost-advance case. + typedef int fp_inc_arr_t[D+1]; + function automatic fp_inc_arr_t INIT_FP_INC(); + automatic fp_inc_arr_t a; + automatic int unsigned fw = 0; // cumulative fp_rewind, built inside out + for(int i = D; i >= 0; i--) begin + if(i < int'(D)) fw = R_FLAG[i+1]? (DIMS[i]-1) * COEFS[i] + fw : 0; + a[i] = R_FLAG[i]? int'(fw) - int'(W[i]) : 0; + end + return a; + endfunction : INIT_FP_INC + localparam fp_inc_arr_t TERMINAL_FP_INC = INIT_FP_INC(); + + // Maximum buffer capacity requirement: the larger of the max backward + // read-pointer retraction and the max read-free pointer gap. + function automatic int unsigned INIT_MAX_OCCUPANCY(); + automatic int unsigned m = 0; + automatic int unsigned rw = 0; + automatic int unsigned fw = 0; + for(int unsigned i = 0; i < D; i++) begin + automatic int t = -TERMINAL_RP_INC[i]; + if(t > int'(m)) m = t; + end + for(int i = D-1; i >= 0; i--) begin + rw = (DIMS[i]-1) * COEFS[i] + rw; + fw = R_FLAG[i+1]? (DIMS[i]-1) * COEFS[i] + fw : 0; + if(rw - fw > m) m = rw - fw; + end + return m; + endfunction : INIT_MAX_OCCUPANCY + + //=== Buffer Sizing ===================================================== + localparam int unsigned WP_DELAY = 1; + localparam int unsigned MAX_OCCUPANCY = INIT_MAX_OCCUPANCY(); + localparam int unsigned ADDR_BITS = $clog2(MAX_OCCUPANCY + WP_DELAY + 2); + localparam int unsigned BUF_SIZE = 1 << ADDR_BITS; + + // Pointer type: one extra bit for signed wrap-around detection. + typedef logic signed [ADDR_BITS:0] ptr_t; + + // Pointer increment type: must accommodate the largest absolute increment. + function automatic int unsigned INIT_MAX_ABS_INC(); + automatic int unsigned m = 0; + for(int unsigned i = 0; i <= D; i++) begin + automatic int unsigned rp_abs = TERMINAL_RP_INC[i] < 0? -TERMINAL_RP_INC[i] : TERMINAL_RP_INC[i]; + automatic int unsigned fp_abs = TERMINAL_FP_INC[i] < 0? -TERMINAL_FP_INC[i] : TERMINAL_FP_INC[i]; + if(rp_abs > m) m = rp_abs; + if(fp_abs > m) m = fp_abs; + end + return m; + endfunction : INIT_MAX_ABS_INC + localparam int unsigned INC_BITS = 1 + $clog2(INIT_MAX_ABS_INC() + 1); + typedef logic signed [INC_BITS-1:0] inc_t; + + //=== Nest Counters ===================================================== + // done[i]: level i has exhausted its iterations (sign-bit of Cnt). + // term[i]: level i and all inner levels completed simultaneously. + uwire [D:0] done; + uwire [D:0] term; + assign done[D] = 1; + assign term[D] = 1; + + uwire advance; // forward-declared, defined in output section + + for(genvar i = 0; i < D; i++) begin : genCnt + uwire step = advance && term[i+1]; + + if(DIMS[i] == 1) begin : genTrivial + assign done[i] = 1; + end : genTrivial + else begin : genCounter + logic signed [$clog2(DIMS[i]-1):0] Cnt = DIMS[i]-2; // DIMS[i]-2, ..., 1, 0, -1 (done) + always_ff @(posedge clk) begin + if(rst) Cnt <= DIMS[i]-2; + else if(step) Cnt <= Cnt + (done[i]? $signed(DIMS[i])-1 : -1); + end + assign done[i] = Cnt[$left(Cnt)]; + end : genCounter + + assign term[i] = term[i+1] && done[i]; + end : genCnt + + //=== Pointer Increment Mux (Combinational) ============================= + inc_t rp_inc; + inc_t fp_inc; + always_comb begin + rp_inc = 0; + fp_inc = 0; + for(int i = D; i >= 0; i--) begin + if(term[i]) begin + rp_inc = TERMINAL_RP_INC[i]; + if(R_FLAG[i]) fp_inc = TERMINAL_FP_INC[i]; + end + end + end + + //=== Circular Buffer and Pointer Management ============================ + logic [DATA_WIDTH-1:0] Buf[BUF_SIZE]; + ptr_t Wp = 0; + ptr_t WpZ = 0; + ptr_t Rp = 0; + ptr_t Cap = -BUF_SIZE+1; // -BUF_SIZE+1, ..., -1, 0 (full) + + uwire has_data = $signed(Rp - WpZ) < 0; + + assign irdy = Cap[$left(Cap)]; + + // Buffer memory — one write port, one registered read port. + // Speculative pre-fetch: on advance, read from the next Rp so that + // BufRd is ready without a settling cycle. + logic [DATA_WIDTH-1:0] BufRd; + uwire ptr_t rd_ptr = Rp + (advance? ptr_t'(rp_inc) : ptr_t'(0)); + always_ff @(posedge clk) begin + if(irdy) Buf[Wp[ADDR_BITS-1:0]] <= idat; + BufRd <= Buf[rd_ptr[ADDR_BITS-1:0]]; + end + + always_ff @(posedge clk) begin + if(rst) begin + Wp <= 0; + WpZ <= 0; + Rp <= 0; + Cap <= -BUF_SIZE+1; + end + else begin + automatic logic istep = irdy && ivld; + WpZ <= Wp; + Wp <= Wp + istep; + Cap <= Cap + (advance? ptr_t'(fp_inc) : ptr_t'(0)) + istep; + if(advance) Rp <= Rp + ptr_t'(rp_inc); + end + end + + //=== Output Stage ====================================================== + logic OVld = 0; + logic [DATA_WIDTH-1:0] OBuf = 'x; + logic [D-1:0] OLst = 'x; + logic [D-1:0] ODone = 'x; + always_ff @(posedge clk) begin + if(rst) begin + OVld <= 0; + OBuf <= 'x; + OLst <= 'x; + ODone <= 'x; + end + else if(!OVld || ordy) begin + OVld <= has_data; + OBuf <= BufRd; + OLst <= term[D-1:0]; + ODone <= done[D-1:0]; + end + end + + assign advance = has_data && (!OVld || ordy); + + assign odat = OBuf; + assign ovld = OVld; + assign olst = OLst; + assign odone = ODone; + +endmodule : input_gen diff --git a/finn-rtllib/mvu_tiled/mmu/1d/collect_out_1d.sv b/finn-rtllib/mvu_tiled/mmu/1d/collect_out_1d.sv new file mode 100644 index 0000000000..ab874aac47 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/1d/collect_out_1d.sv @@ -0,0 +1,111 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module collect_out_1d #( + int unsigned PE, + int unsigned ACCU_WIDTH, + + int unsigned QDEPTH = 2 * PE, + int unsigned QCNT_BITS = $clog2(QDEPTH), + int unsigned Q_MAX = PE, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + output logic en, + + // Input Stream + input logic [PE-1:0][ACCU_WIDTH-1:0] p_tdata, + input logic p_tvalid, + + // Output Stream + output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Queueing +// --------------------------------------------------------------------- + logic q_in_tready; + logic q_out_tready, q_out_tvalid; + logic [PE-1:0][ACCU_WIDTH-1:0] q_out_tdata; + logic [QCNT_BITS-1:0] q_count; + logic en_int; + + for(genvar i = 0; i < PE; i++) begin + if(i == 0) begin + Q_srl #( + .depth(QDEPTH), + .width(ACCU_WIDTH) + ) inst_queue ( + .clock(clk), .reset(rst), + .count(q_count), .maxcount(), + .i_v(p_tvalid), .i_r(q_in_tready), .i_d(p_tdata[i]), + .o_v(q_out_tvalid), .o_r(q_out_tready), .o_d(q_out_tdata[i]) + ); + end else begin + Q_srl #( + .depth(QDEPTH), + .width(ACCU_WIDTH) + ) inst_queue ( + .clock(clk), .reset(rst), + .count(), .maxcount(), + .i_v(p_tvalid), .i_r(), .i_d(p_tdata[i]), + .o_v(), .o_r(q_out_tready), .o_d(q_out_tdata[i]) + ); + end + end + + // Global enable + assign en_int = !(q_count > Q_MAX); + + always_ff @( posedge clk ) begin + if(rst) begin + en <= 1'b0; + end + else begin + en <= en_int; + end + end + +// Output +// --------------------------------------------------------------------- + skid #(.DATA_WIDTH(PE*ACCU_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat(q_out_tdata), .ivld(q_out_tvalid), .irdy(q_out_tready), + .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/1d/cu_mmau_1d.sv b/finn-rtllib/mvu_tiled/mmu/1d/cu_mmau_1d.sv new file mode 100644 index 0000000000..0f13083d09 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/1d/cu_mmau_1d.sv @@ -0,0 +1,403 @@ +/****************************************************************************** + * Copyright (C) 2025, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * @brief Compute unit (DSP grid) - MMAU + * @author Dario Korolija + *****************************************************************************/ + +module cu_mmau_1d #( + int unsigned PE, + int unsigned CLEN, + int unsigned CU_SIMD, + + int unsigned ACTIVATION_WIDTH, + int unsigned WEIGHT_WIDTH, + int unsigned ACCU_WIDTH, + + bit SIGNED_ACTIVATIONS = 1, + int unsigned FORCE_BEHAVIOURAL = 0 + ) ( + // Global Control + input logic clk, + input logic rst, + + // Enable + output logic en, + + // Input + input logic ivld, + input logic [CLEN-1:0] ilast, + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] a, + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] w, + + // Ouput + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata + ); + +// Startup Recovery Watchdog +// The DSP slice needs 100ns of recovery time after initial startup before +// being able to ingest input properly. This watchdog discovers violating +// stimuli during simulation and produces a corresponding warning. +//------------------------------------------------------------------------------------ + if(1) begin : blkRecoveryWatch + logic Dirty = 1; + initial begin + #100ns; + Dirty <= 0; + end + + always_ff @(posedge clk) begin + assert(!Dirty || rst) else begin + $warning("%m: Feeding input during DSP startup recovery. Expect functional errors."); + end + end + end : blkRecoveryWatch + +// Shifts - activations and weights +//------------------------------------------------------------------------------------ + localparam int unsigned PAD_BITS_ACT = 9 - ACTIVATION_WIDTH; + localparam int unsigned PAD_BITS_WEIGHT = 8 - WEIGHT_WIDTH; + + logic [CLEN:0][PE-1:0][CU_SIMD*WEIGHT_WIDTH-1:0] Wc; + logic [CLEN-1:0][PE-1:0][23:0] Wc_int; + + for(genvar i = 0; i < PE; i++) begin + assign Wc[0][i] = w[i]; + + for (genvar k = 0; k < CU_SIMD; k++) begin + assign Wc_int[0][i][8*k +: 8] = + PAD_BITS_WEIGHT == 0 ? Wc[0][i][WEIGHT_WIDTH*k+:WEIGHT_WIDTH] : { {PAD_BITS_WEIGHT{Wc[0][i][k*WEIGHT_WIDTH+WEIGHT_WIDTH-1]}}, Wc[0][i][k*WEIGHT_WIDTH+:WEIGHT_WIDTH] }; + end + end + + /* + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 1; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + Wc[i][j] <= 'X; + end + end + end + for(int i = 1; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + if(ivld) begin + Wc[i][j] <= Wc[i-1][j]; + end + end + end + end + */ + +// Shifts - per DSP +//------------------------------------------------------------------------------------ + localparam int unsigned DSP_PIPELINE_STAGES = 3; + logic [CLEN-1:0][DSP_PIPELINE_STAGES:0] Lc; + + for(genvar i = 0; i < CLEN; i++) begin + assign Lc[i][0] = ilast[i]; + end + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + for(int k = 1; k <= DSP_PIPELINE_STAGES; k++) begin + Lc[i][k] <= 'X; + end + end + end + else begin + for(int i = 0; i < CLEN; i++) begin + for(int k = 1; k <= DSP_PIPELINE_STAGES; k++) begin + if(ivld) begin + Lc[i][k] <= Lc[i][k-1]; + end + end + end + end + end + +// Instantiate PE x CLEN DSPs +//------------------------------------------------------------------------------------ + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] pout; + + /* if(FORCE_BEHAVIOURAL == 1) begin + logic [CLEN-1:0][CU_SIMD*ACTIVATION_WIDTH-1:0] Ac_int; + logic [CLEN-1:0][PE-1:0][CU_SIMD*WEIGHT_WIDTH-1:0] Wc_int; + logic [CLEN-1:0][PE-1:0][CU_SIMD-1:0][ACCU_WIDTH-1:0] Mc_int_part; + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Mc_int_sum; + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Mc_int; + + + for (genvar i = 0; i < CLEN; i++) begin + always_ff @(posedge clk) begin + if(rst) begin + Ac_int[i] <= 'X; + end else begin + if(ivld) begin + Ac_int[i] <= a[i]; + end + end + end + + for (genvar j = 0; j < PE; j++) begin + always_comb begin + Mc_int_sum[i][j] = 0; + + for(int k = 0; k < CU_SIMD; k++) begin + Mc_int_part[i][j][k] = $signed(Ac_int[i][k*ACTIVATION_WIDTH+:ACTIVATION_WIDTH]) * $signed(Wc_int[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH]); + Mc_int_sum[i][j] = $signed(Mc_int_sum[i][j]) + $signed(Mc_int_part[i][j][k]); + end + end + + always_ff @(posedge clk) begin + if(rst) begin + Wc_int[i][j] <= '0; + Mc_int[i][j] <= '0; + pout[i][j] <= '0; + end else begin + if(ivld) begin + Wc_int[i][j] <= Wc[i][j]; + Mc_int[i][j] <= $signed(Mc_int_sum[i][j]); + pout[i][j] <= Lc[i][DSP_PIPELINE_STAGES] ? $signed(Mc_int[i][j]) : $signed(Mc_int[i][j]) + $signed(pout[i][j]); + end + end + end + end + end + end else begin */ + localparam int INTERNAL_REGS = 1; // 1 : 0 + localparam bit PREG = 1; + localparam int CC_LEN = CLEN / 4; + + logic [CLEN-1:0][26:0] Ac_int; + //logic [CLEN-1:0][PE-1:0][23:0] Wc_int; + logic [CLEN-1:0][PE-1:0][23:0] tmp_cc; + + for (genvar i = 0; i < CLEN; i++) begin + for (genvar k = 0; k < CU_SIMD; k++) begin + assign Ac_int[i][9*k +: 9] = + SIGNED_ACTIVATIONS ? PAD_BITS_ACT == 0 ? a[i][k] : { {PAD_BITS_ACT{a[i][k][ACTIVATION_WIDTH-1]}}, a[i][k] } + : PAD_BITS_ACT == 0 ? a[i][k] : { {PAD_BITS_ACT{1'b0}}, a[i][k] } ; + end + + + for (genvar j = 0; j < PE; j++) begin + /* for (genvar k = 0; k < CU_SIMD; k++) begin + assign Wc_int[i][j][8*k +: 8] = + PAD_BITS_WEIGHT == 0 ? Wc[i][j][WEIGHT_WIDTH*k+:WEIGHT_WIDTH] : { {PAD_BITS_WEIGHT{Wc[i][j][k*WEIGHT_WIDTH+WEIGHT_WIDTH-1]}}, Wc[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH] }; + end */ + + + DSP58 #( + // Feature Control Attributes: Data Path Selection + .AMULTSEL("A"), // Selects A input to multiplier (A, AD) + .A_INPUT("DIRECT"), // Selects A input source, "DIRECT" (A port) or "CASCADE" (ACIN port) + .BMULTSEL("B"), // Selects B input to multiplier (AD, B) + .B_INPUT((i % CC_LEN == 0) ? "DIRECT" : "CASCADE"), // Selects B input source, "DIRECT" (B port) or "CASCADE" (BCIN port) + .DSP_MODE("INT8"), // Configures DSP to a particular mode of operation. Set to INT24 for + // legacy mode. + .PREADDINSEL("A"), // Selects input to pre-adder (A, B) + .RND(58'h000000000000000), // Rounding Constant + .USE_MULT("MULTIPLY"), // Select multiplier usage (DYNAMIC, MULTIPLY, NONE) + .USE_SIMD("ONE58"), // SIMD selection (FOUR12, ONE58, TWO24) + .USE_WIDEXOR("FALSE"), // Use the Wide XOR function (FALSE, TRUE) + .XORSIMD("XOR24_34_58_116"), // Mode of operation for the Wide XOR (XOR12_22, XOR24_34_58_116) + // Pattern Detector Attributes: Pattern Detection Configuration + .AUTORESET_PATDET("NO_RESET"), // NO_RESET, RESET_MATCH, RESET_NOT_MATCH + .AUTORESET_PRIORITY("RESET"), // Priority of AUTORESET vs. CEP (CEP, RESET). + .MASK(58'h0ffffffffffffff), // 58-bit mask value for pattern detect (1=ignore) + .PATTERN(58'h000000000000000), // 58-bit pattern match for pattern detect + .SEL_MASK("MASK"), // C, MASK, ROUNDING_MODE1, ROUNDING_MODE2 + .SEL_PATTERN("PATTERN"), // Select pattern value (C, PATTERN) + .USE_PATTERN_DETECT("NO_PATDET"), // Enable pattern detect (NO_PATDET, PATDET) + // Programmable Inversion Attributes: Specifies built-in programmable inversion on specific pins + .IS_ALUMODE_INVERTED(4'b0000), // Optional inversion for ALUMODE + .IS_CARRYIN_INVERTED(1'b0), // Optional inversion for CARRYIN + .IS_CLK_INVERTED(1'b0), // Optional inversion for CLK + .IS_INMODE_INVERTED(5'b00000), // Optional inversion for INMODE + .IS_NEGATE_INVERTED(3'b000), // Optional inversion for NEGATE + .IS_OPMODE_INVERTED({2'b00, // W: LAST ? 0 : P + 3'b000, // Z: 0 + 2'b01, // Y : M + 2'b01 // X: M + }), // Optional inversion for OPMODE + .IS_RSTALLCARRYIN_INVERTED(1'b0), // Optional inversion for RSTALLCARRYIN + .IS_RSTALUMODE_INVERTED(1'b0), // Optional inversion for RSTALUMODE + .IS_RSTA_INVERTED(1'b0), // Optional inversion for RSTA + .IS_RSTB_INVERTED(1'b0), // Optional inversion for RSTB + .IS_RSTCTRL_INVERTED(1'b0), // Optional inversion for STCONJUGATE_A + .IS_RSTC_INVERTED(1'b0), // Optional inversion for RSTC + .IS_RSTD_INVERTED(1'b0), // Optional inversion for RSTD + .IS_RSTINMODE_INVERTED(1'b0), // Optional inversion for RSTINMODE + .IS_RSTM_INVERTED(1'b0), // Optional inversion for RSTM + .IS_RSTP_INVERTED(1'b0), // Optional inversion for RSTP + // Register Control Attributes: Pipeline Register Configuration + .ACASCREG(INTERNAL_REGS), // Number of pipeline stages between A/ACIN and ACOUT (0-2) + .ADREG(0), // Pipeline stages for pre-adder (0-1) + .ALUMODEREG(0), // Pipeline stages for ALUMODE (0-1) + .AREG(INTERNAL_REGS), // Pipeline stages for A (0-2) + .BCASCREG(INTERNAL_REGS), // Number of pipeline stages between B/BCIN and BCOUT (0-2) + .BREG(INTERNAL_REGS), // Pipeline stages for B (0-2) + .CARRYINREG(0), // Pipeline stages for CARRYIN (0-1) + .CARRYINSELREG(0), // Pipeline stages for CARRYINSEL (0-1) + .CREG(0), // Pipeline stages for C (0-1) + .DREG(0), // Pipeline stages for D (0-1) + .INMODEREG(1), // Pipeline stages for INMODE (0-1) + .MREG(1), // Multiplier pipeline stages (0-1) + .OPMODEREG(1), // Pipeline stages for OPMODE (0-1) + .PREG(PREG), // Number of pipeline stages for P (0-1) + .RESET_MODE("SYNC") // Selection of synchronous or asynchronous reset. (ASYNC, SYNC). + ) + DSP58_inst ( + // Cascade outputs: Cascade Ports + .ACOUT(), // 34-bit output: A port cascade + .BCOUT((i % CC_LEN == CC_LEN-1) ? tmp_cc[i+1][j] : Wc_int[i+1][j]), // 24-bit output: B cascade + .CARRYCASCOUT(), // 1-bit output: Cascade carry + .MULTSIGNOUT(), // 1-bit output: Multiplier sign cascade + .PCOUT() , // 58-bit output: Cascade output + // Control outputs: Control Inputs/Status Bits + .OVERFLOW(), // 1-bit output: Overflow in add/acc + .PATTERNBDETECT(), // 1-bit output: Pattern bar detect + .PATTERNDETECT(), // 1-bit output: Pattern detect + .UNDERFLOW(), // 1-bit output: Underflow in add/acc + // Data outputs: Data Ports + .CARRYOUT(), // 4-bit output: Carry + .P(pout[i][j]), // 58-bit output: Primary data + .XOROUT(), // 8-bit output: XOR data + // Cascade inputs: Cascade Ports + .ACIN('x), // 34-bit input: A cascade data + .BCIN((i % CC_LEN == 0) ? 'x : Wc_int[i][j]), // 24-bit input: B cascade + .CARRYCASCIN('x), // 1-bit input: Cascade carry + .MULTSIGNIN('x), // 1-bit input: Multiplier sign cascade + .PCIN('x), // 58-bit input: P cascade + // Control inputs: Control Inputs/Status Bits + .ALUMODE(4'h0), // 4-bit input: ALU control + .CARRYINSEL('0), // 3-bit input: Carry select + .CLK(clk), // 1-bit input: Clock + .INMODE({5'b10001}), // 5-bit input: INMODE control + .NEGATE('0), // 3-bit input: Negates the input of the multiplier + .OPMODE({ + Lc[i][DSP_PIPELINE_STAGES-1] ? 2'b00 : 2'b01, + 7'b000_0000 + }), // 9-bit input: Operation mode + // Data inputs: Data Ports + .A({ 7'b0, Ac_int[i] }), // 34-bit input: A data + .B((i % CC_LEN == 0) ? Wc_int[i][j] : 'x), // 24-bit input: B data + .C('x), // 58-bit input: C data + .CARRYIN('0), // 1-bit input: Carry-in + .D('x), // 27-bit input: D data + // Reset/Clock Enable inputs: Reset/Clock Enable Inputs + .ASYNC_RST('0), // 1-bit input: Asynchronous reset for all registers. + .CEA1(ivld), // 1-bit input: Clock enable for 1st stage AREG + .CEA2('0), // 1-bit input: Clock enable for 2nd stage AREG + .CEAD('0), // 1-bit input: Clock enable for ADREG + .CEALUMODE('0), // 1-bit input: Clock enable for ALUMODE + .CEB1(ivld), // 1-bit input: Clock enable for 1st stage BREG + .CEB2('0), // 1-bit input: Clock enable for 2nd stage BREG + .CEC('0), // 1-bit input: Clock enable for CREG + .CECARRYIN('0), // 1-bit input: Clock enable for CARRYINREG + .CECTRL(ivld), // 1-bit input: Clock enable for OPMODEREG and CARRYINSELREG + .CED('0), // 1-bit input: Clock enable for DREG + .CEINMODE('1), // 1-bit input: Clock enable for INMODEREG + .CEM(ivld), // 1-bit input: Clock enable for MREG + .CEP(ivld), // 1-bit input: Clock enable for PREG + .RSTA(rst), // 1-bit input: Reset for AREG + .RSTALLCARRYIN('0), // 1-bit input: Reset for CARRYINREG + .RSTALUMODE('0), // 1-bit input: Reset for ALUMODEREG + .RSTB(rst), // 1-bit input: Reset for BREG + .RSTC('0), // 1-bit input: Reset for CREG + .RSTCTRL(rst), // 1-bit input: Reset for OPMODEREG and CARRYINSELREG + .RSTD('0), // 1-bit input: Reset for DREG and ADREG + .RSTINMODE(rst), // 1-bit input: Reset for INMODE register + .RSTM(rst), // 1-bit input: Reset for MREG + .RSTP(rst) // 1-bit input: Reset for PREG + ); + + + if(i % CC_LEN == CC_LEN-1) begin + sft_reg #( + .N(CC_LEN) + ) inst_sft_reg ( + .clk(clk), + .ivld(ivld), + .din(Wc_int[i-(CC_LEN-1)][j]), + .dout(Wc_int[i+1][j]) + ); + end + + end + end + // end + +// Collect +//------------------------------------------------------------------------------------ + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Pc; + logic [CLEN-1:0] Pc_vld; + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + Pc[i] <= '0; + end + end else begin + for(int i = 0; i < CLEN; i++) begin + if(ivld) begin + if(i == CLEN-1) begin + Pc[i] <= pout[i]; + Pc_vld[i] <= Lc[i][DSP_PIPELINE_STAGES]; + end else begin + Pc[i] <= Lc[i][DSP_PIPELINE_STAGES] ? pout[i] : Pc[i+1]; + Pc_vld[i] <= Lc[i][DSP_PIPELINE_STAGES] ? 1'b1 : Pc_vld[i+1]; + end + end + end + end + end + + logic ovld; + logic [PE-1:0][ACCU_WIDTH-1:0] p; + + assign ovld = Pc_vld[0]; + assign p = Pc[0]; + + collect_out_1d #( + .PE(PE), + .ACCU_WIDTH(ACCU_WIDTH) + ) inst_collect_out ( + .clk(clk), .rst(rst), + .en(en), + .p_tdata(p), .p_tvalid(ovld), + .m_axis_tdata(m_axis_tdata), .m_axis_tvalid(m_axis_tvalid), .m_axis_tready(m_axis_tready) + ); + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/1d/sched_weights_1d.sv b/finn-rtllib/mvu_tiled/mmu/1d/sched_weights_1d.sv new file mode 100644 index 0000000000..de2689e957 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/1d/sched_weights_1d.sv @@ -0,0 +1,141 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module sched_weights_1d #( + int unsigned CU_SIMD, + int unsigned PE, + int unsigned WEIGHT_WIDTH, + + int unsigned N_BEATS_OP, + int unsigned N_BEATS_EP, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + // Input Stream + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] s_axis_tdata, + input logic s_axis_tvalid, + output logic s_axis_tready, + + // Output Stream + output logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Params +// --------------------------------------------------------------------- + localparam integer CNT_EPLG_BITS = (N_BEATS_OP > N_BEATS_EP) ? + (N_BEATS_OP == 1) ? 1 : $clog2(N_BEATS_OP) : + (N_BEATS_EP == 1) ? 1 : $clog2(N_BEATS_EP); + +// Queueing +// --------------------------------------------------------------------- + logic s_out_tready, s_out_tvalid; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] s_out_tdata; + + skid #(.DATA_WIDTH(PE*CU_SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_ireg ( + .clk(clk), .rst(rst), + .idat(s_axis_tdata), .ivld(s_axis_tvalid), .irdy(s_axis_tready), + .odat(s_out_tdata), .ovld(s_out_tvalid), .ordy(s_out_tready) + ); + +// Shifting +// --------------------------------------------------------------------- + logic valid_C = '0, valid_N; + logic eplg_C = '0, eplg_N; + logic [CNT_EPLG_BITS-1:0] cnt_eplg_C = '0, cnt_eplg_N; + + logic ovld, ordy; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] odat; + + // REG + always_ff @(posedge clk) begin + if(rst) begin + valid_C <= '1; + eplg_C <= 1'b0; + cnt_eplg_C <= '0; + end else begin + valid_C <= valid_N; + eplg_C <= eplg_N; + cnt_eplg_C <= cnt_eplg_N; + end + end + + // DP + always_comb begin + valid_N = valid_C; + eplg_N = eplg_C; + cnt_eplg_N = cnt_eplg_C; + + // Read + if (ovld && ordy) begin + // Shift ctrl + if(eplg_C) begin + if(cnt_eplg_C == N_BEATS_EP-1) begin + eplg_N = 1'b0; + cnt_eplg_N = 0; + valid_N = 1'b1; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N = 1'b0; + end + end else begin + if(cnt_eplg_C == N_BEATS_OP-1) begin + eplg_N = 1'b1; + cnt_eplg_N = 0; + valid_N = 1'b0; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N = 1'b1; + end + end + end + end + + // Output valid + assign ovld = !((s_out_tvalid && valid_C) != valid_C); + assign s_out_tready = (ovld && ordy) && valid_C; + assign odat = valid_C ? s_out_tdata : '0; + +// Oreg +// --------------------------------------------------------------------- + skid #(.DATA_WIDTH(PE*CU_SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat(odat), .ivld(ovld), .irdy(ordy), + .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/1d/sft_reg.sv b/finn-rtllib/mvu_tiled/mmu/1d/sft_reg.sv new file mode 100644 index 0000000000..65e674e53b --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/1d/sft_reg.sv @@ -0,0 +1,25 @@ +module sft_reg #( + int N = 4, + int DATA_BITS = 24 +)( + input logic clk, + input logic ivld, + input logic [DATA_BITS-1:0] din, + output logic [DATA_BITS-1:0] dout +); + + // A 2D array representing the shift stages + logic [N-1:0][DATA_BITS-1:0] shift_pipe; + + always_ff @(posedge clk) begin + if (ivld) begin + // Shift the bits in + shift_pipe <= {shift_pipe[N-2:0], din}; + end + end + + // The tool sees this lack of reset and constant index + // and maps it to an SRL16 automatically. + assign dout = shift_pipe[N-1]; + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/2d/collect_out_2d.sv b/finn-rtllib/mvu_tiled/mmu/2d/collect_out_2d.sv new file mode 100644 index 0000000000..b1949e0620 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/2d/collect_out_2d.sv @@ -0,0 +1,120 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module collect_out_2d #( + int unsigned PE, + int unsigned ACCU_WIDTH, + + int unsigned QDEPTH = 2 * PE, + int unsigned QCNT_BITS = $clog2(QDEPTH), + int unsigned Q_MAX = PE, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + output logic en, + + // Input Stream + input logic [PE-1:0][ACCU_WIDTH-1:0] p_tdata, + input logic [PE-1:0] p_tvalid, + + // Output Stream + output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Queueing +// --------------------------------------------------------------------- + logic [PE-1:0] q_in_tready, q_in_tvalid; + logic [PE-1:0][ACCU_WIDTH-1:0] q_in_tdata; + logic [PE-1:0] q_out_tready, q_out_tvalid; + logic [PE-1:0][ACCU_WIDTH-1:0] q_out_tdata; + logic [PE-1:0][QCNT_BITS-1:0] q_count; + logic en_int; + + assign q_in_tvalid = p_tvalid; + assign q_in_tdata = p_tdata; + + for(genvar i = 0; i < PE; i++) begin + Q_srl #( + .depth(QDEPTH), + .width(ACCU_WIDTH) + ) inst_queue ( + .clock(clk), .reset(rst), + .count(q_count[i]), .maxcount(), + .i_v(q_in_tvalid[i]), .i_r(q_in_tready[i]), .i_d(q_in_tdata[i]), + .o_v(q_out_tvalid[i]), .o_r(q_out_tready[i]), .o_d(q_out_tdata[i]) + ); + end + + // Global enable + always_comb begin + en_int = 1'b1; + + for(int i = 0; i < PE; i++) begin + if(q_count[i] > Q_MAX) + en_int = 1'b0; + end + end + + always_ff @( posedge clk ) begin + if(rst) begin + en <= 1'b0; + end + else begin + en <= en_int; + end + end + +// Output +// --------------------------------------------------------------------- + logic ovld; + logic ordy; + logic [PE-1:0][ACCU_WIDTH-1:0] odat; + + assign odat = q_out_tdata; + assign ovld = &q_out_tvalid; + for(genvar i = 0; i < PE; i++) begin + assign q_out_tready[i] = ovld && ordy; + end + + skid #(.DATA_WIDTH(PE*ACCU_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat(odat), .ivld(ovld), .irdy(ordy), + .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/2d/cu_mmau_2d.sv b/finn-rtllib/mvu_tiled/mmu/2d/cu_mmau_2d.sv new file mode 100644 index 0000000000..40379ea16f --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/2d/cu_mmau_2d.sv @@ -0,0 +1,432 @@ +/****************************************************************************** + * Copyright (C) 2025, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * @brief Compute unit (DSP grid) - MMAU + * @author Dario Korolija + *****************************************************************************/ + +module cu_mmau_2d #( + int unsigned PE, + int unsigned CLEN, + int unsigned CU_SIMD, + + int unsigned ACTIVATION_WIDTH, + int unsigned WEIGHT_WIDTH, + int unsigned ACCU_WIDTH, + + bit SIGNED_ACTIVATIONS = 1, + int unsigned FORCE_BEHAVIOURAL = 0 + ) ( + // Global Control + input logic clk, + input logic rst, + + // Enable + output logic en, + + // Input + input logic ivld, + input logic [CLEN-1:0] ilast, + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] a, + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] w, + + // Ouput + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata + ); + + +// Startup Recovery Watchdog +// The DSP slice needs 100ns of recovery time after initial startup before +// being able to ingest input properly. This watchdog discovers violating +// stimuli during simulation and produces a corresponding warning. +//------------------------------------------------------------------------------------ + if(1) begin : blkRecoveryWatch + logic Dirty = 1; + initial begin + #100ns; + Dirty <= 0; + end + + always_ff @(posedge clk) begin + assert(!Dirty || rst) else begin + $warning("%m: Feeding input during DSP startup recovery. Expect functional errors."); + end + end + end : blkRecoveryWatch + +// Shifts - activations and weights +//------------------------------------------------------------------------------------ + localparam int unsigned PAD_BITS_ACT = 9 - ACTIVATION_WIDTH; + localparam int unsigned PAD_BITS_WEIGHT = 8 - WEIGHT_WIDTH; + + logic [CLEN-1:0][PE-1:0][CU_SIMD*ACTIVATION_WIDTH-1:0] Ac; + logic [CLEN-1:0][PE-1:0] Ac_last; + logic [CLEN-1:0][PE-1:0][CU_SIMD*WEIGHT_WIDTH-1:0] Wc; + + for(genvar i = 0; i < CLEN; i++) begin + assign Ac[i][0] = a[i]; + assign Ac_last[i][0] = ilast[i]; + end + + for(genvar i = 0; i < PE; i++) begin + assign Wc[0][i] = w[i]; + end + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 1; j < PE; j++) begin + Ac[i][j] <= '0; + Ac_last[i][j] <= 1'b0; + end + end + + for(int i = 1; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + Wc[i][j] <= '0; + end + end + end + for(int i = 0; i < CLEN; i++) begin + for(int j = 1; j < PE; j++) begin + if(ivld) begin + Ac[i][j] <= Ac[i][j-1]; + Ac_last[i][j] <= Ac_last[i][j-1]; + end + end + end + + for(int i = 1; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + if(ivld) begin + Wc[i][j] <= Wc[i-1][j]; + end + end + end + end + +// Shifts - per DSP +//------------------------------------------------------------------------------------ + localparam int unsigned DSP_PIPELINE_STAGES = 3; + logic [CLEN-1:0][PE-1:0][DSP_PIPELINE_STAGES:0] Lc; + + for(genvar i = 0; i < CLEN; i++) begin + for(genvar j = 0; j < PE; j++) begin + assign Lc[i][j][0] = Ac_last[i][j]; + end + end + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + for(int k = 1; k <= DSP_PIPELINE_STAGES; k++) begin + Lc[i][j][k] <= 1'b0; + end + end + end + end + else begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + for(int k = 1; k <= DSP_PIPELINE_STAGES; k++) begin + if(ivld) begin + Lc[i][j][k] <= Lc[i][j][k-1]; + end + end + end + end + end + end + +// Instantiate PE x CLEN DSPs +//------------------------------------------------------------------------------------ + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] pout; + + if(FORCE_BEHAVIOURAL == 1) begin + logic [CLEN-1:0][PE-1:0][CU_SIMD*ACTIVATION_WIDTH-1:0] Ac_int; + logic [CLEN-1:0][PE-1:0][CU_SIMD*WEIGHT_WIDTH-1:0] Wc_int; + logic [CLEN-1:0][PE-1:0][CU_SIMD-1:0][ACCU_WIDTH-1:0] Mc_int_part; + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Mc_int_sum; + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Mc_int; + + + for (genvar i = 0; i < CLEN; i++) begin + for (genvar j = 0; j < PE; j++) begin + always_comb begin + Mc_int_sum[i][j] = 0; + + for(int k = 0; k < CU_SIMD; k++) begin + Mc_int_part[i][j][k] = $signed(Ac_int[i][j][k*ACTIVATION_WIDTH+:ACTIVATION_WIDTH]) * $signed(Wc_int[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH]); + Mc_int_sum[i][j] = $signed(Mc_int_sum[i][j]) + $signed(Mc_int_part[i][j][k]); + end + end + + always_ff @(posedge clk) begin + if(rst) begin + Ac_int[i][j] <= '0; + Wc_int[i][j] <= '0; + Mc_int[i][j] <= '0; + pout[i][j] <= '0; + end else begin + if(ivld) begin + Ac_int[i][j] <= Ac[i][j]; + Wc_int[i][j] <= Wc[i][j]; + Mc_int[i][j] <= $signed(Mc_int_sum[i][j]); + pout[i][j] <= Lc[i][j][DSP_PIPELINE_STAGES] ? $signed(Mc_int[i][j]) : $signed(Mc_int[i][j]) + $signed(pout[i][j]); + end + end + end + end + end + end else begin + localparam int INTERNAL_REGS = 1; // 1 : 0 + localparam bit PREG = 1; + + logic [CLEN-1:0][PE-1:0][26:0] Ac_int; + logic [CLEN-1:0][PE-1:0][23:0] Wc_int; + + for (genvar i = 0; i < CLEN; i++) begin + for (genvar j = 0; j < PE; j++) begin + + for (genvar k = 0; k < CU_SIMD; k++) begin + assign Ac_int[i][j][9*k +: 9] = + SIGNED_ACTIVATIONS ? PAD_BITS_ACT == 0 ? Ac[i][j][ACTIVATION_WIDTH*k+:ACTIVATION_WIDTH] : { {PAD_BITS_ACT{Ac[i][j][k*ACTIVATION_WIDTH+ACTIVATION_WIDTH-1]}}, Ac[i][j][k*ACTIVATION_WIDTH+:ACTIVATION_WIDTH] } + : PAD_BITS_ACT == 0 ? Ac[i][j][ACTIVATION_WIDTH*k+:ACTIVATION_WIDTH] : { {PAD_BITS_ACT{1'b0}}, Ac[i][j][k*ACTIVATION_WIDTH+:ACTIVATION_WIDTH] } ; + assign Wc_int[i][j][8*k +: 8] = + PAD_BITS_WEIGHT == 0 ? Wc[i][j][WEIGHT_WIDTH*k+:WEIGHT_WIDTH] : { {PAD_BITS_WEIGHT{Wc[i][j][k*WEIGHT_WIDTH+WEIGHT_WIDTH-1]}}, Wc[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH] }; + end + + + DSP58 #( + // Feature Control Attributes: Data Path Selection + .AMULTSEL("A"), // Selects A input to multiplier (A, AD) + .A_INPUT("DIRECT"), // Selects A input source, "DIRECT" (A port) or "CASCADE" (ACIN port) + .BMULTSEL("B"), // Selects B input to multiplier (AD, B) + .B_INPUT("DIRECT"), // Selects B input source, "DIRECT" (B port) or "CASCADE" (BCIN port) + .DSP_MODE("INT8"), // Configures DSP to a particular mode of operation. Set to INT24 for + // legacy mode. + .PREADDINSEL("A"), // Selects input to pre-adder (A, B) + .RND(58'h000000000000000), // Rounding Constant + .USE_MULT("MULTIPLY"), // Select multiplier usage (DYNAMIC, MULTIPLY, NONE) + .USE_SIMD("ONE58"), // SIMD selection (FOUR12, ONE58, TWO24) + .USE_WIDEXOR("FALSE"), // Use the Wide XOR function (FALSE, TRUE) + .XORSIMD("XOR24_34_58_116"), // Mode of operation for the Wide XOR (XOR12_22, XOR24_34_58_116) + // Pattern Detector Attributes: Pattern Detection Configuration + .AUTORESET_PATDET("NO_RESET"), // NO_RESET, RESET_MATCH, RESET_NOT_MATCH + .AUTORESET_PRIORITY("RESET"), // Priority of AUTORESET vs. CEP (CEP, RESET). + .MASK(58'h0ffffffffffffff), // 58-bit mask value for pattern detect (1=ignore) + .PATTERN(58'h000000000000000), // 58-bit pattern match for pattern detect + .SEL_MASK("MASK"), // C, MASK, ROUNDING_MODE1, ROUNDING_MODE2 + .SEL_PATTERN("PATTERN"), // Select pattern value (C, PATTERN) + .USE_PATTERN_DETECT("NO_PATDET"), // Enable pattern detect (NO_PATDET, PATDET) + // Programmable Inversion Attributes: Specifies built-in programmable inversion on specific pins + .IS_ALUMODE_INVERTED(4'b0000), // Optional inversion for ALUMODE + .IS_CARRYIN_INVERTED(1'b0), // Optional inversion for CARRYIN + .IS_CLK_INVERTED(1'b0), // Optional inversion for CLK + .IS_INMODE_INVERTED(5'b00000), // Optional inversion for INMODE + .IS_NEGATE_INVERTED(3'b000), // Optional inversion for NEGATE + .IS_OPMODE_INVERTED({2'b00, // W: LAST ? 0 : P + 3'b000, // Z: 0 + 2'b01, // Y : M + 2'b01 // X: M + }), // Optional inversion for OPMODE + .IS_RSTALLCARRYIN_INVERTED(1'b0), // Optional inversion for RSTALLCARRYIN + .IS_RSTALUMODE_INVERTED(1'b0), // Optional inversion for RSTALUMODE + .IS_RSTA_INVERTED(1'b0), // Optional inversion for RSTA + .IS_RSTB_INVERTED(1'b0), // Optional inversion for RSTB + .IS_RSTCTRL_INVERTED(1'b0), // Optional inversion for STCONJUGATE_A + .IS_RSTC_INVERTED(1'b0), // Optional inversion for RSTC + .IS_RSTD_INVERTED(1'b0), // Optional inversion for RSTD + .IS_RSTINMODE_INVERTED(1'b0), // Optional inversion for RSTINMODE + .IS_RSTM_INVERTED(1'b0), // Optional inversion for RSTM + .IS_RSTP_INVERTED(1'b0), // Optional inversion for RSTP + // Register Control Attributes: Pipeline Register Configuration + .ACASCREG(INTERNAL_REGS), // Number of pipeline stages between A/ACIN and ACOUT (0-2) + .ADREG(0), // Pipeline stages for pre-adder (0-1) + .ALUMODEREG(0), // Pipeline stages for ALUMODE (0-1) + .AREG(INTERNAL_REGS), // Pipeline stages for A (0-2) + .BCASCREG(INTERNAL_REGS), // Number of pipeline stages between B/BCIN and BCOUT (0-2) + .BREG(INTERNAL_REGS), // Pipeline stages for B (0-2) + .CARRYINREG(0), // Pipeline stages for CARRYIN (0-1) + .CARRYINSELREG(0), // Pipeline stages for CARRYINSEL (0-1) + .CREG(0), // Pipeline stages for C (0-1) + .DREG(0), // Pipeline stages for D (0-1) + .INMODEREG(1), // Pipeline stages for INMODE (0-1) + .MREG(1), // Multiplier pipeline stages (0-1) + .OPMODEREG(1), // Pipeline stages for OPMODE (0-1) + .PREG(PREG), // Number of pipeline stages for P (0-1) + .RESET_MODE("SYNC") // Selection of synchronous or asynchronous reset. (ASYNC, SYNC). + ) + DSP58_inst ( + // Cascade outputs: Cascade Ports + .ACOUT(), // 34-bit output: A port cascade + .BCOUT(), // 24-bit output: B cascade + .CARRYCASCOUT(), // 1-bit output: Cascade carry + .MULTSIGNOUT(), // 1-bit output: Multiplier sign cascade + .PCOUT(), // 58-bit output: Cascade output + // Control outputs: Control Inputs/Status Bits + .OVERFLOW(), // 1-bit output: Overflow in add/acc + .PATTERNBDETECT(), // 1-bit output: Pattern bar detect + .PATTERNDETECT(), // 1-bit output: Pattern detect + .UNDERFLOW(), // 1-bit output: Underflow in add/acc + // Data outputs: Data Ports + .CARRYOUT(), // 4-bit output: Carry + .P(pout[i][j]), // 58-bit output: Primary data + .XOROUT(), // 8-bit output: XOR data + // Cascade inputs: Cascade Ports + .ACIN('x), // 34-bit input: A cascade data + .BCIN('x), // 24-bit input: B cascade + .CARRYCASCIN('x), // 1-bit input: Cascade carry + .MULTSIGNIN('x), // 1-bit input: Multiplier sign cascade + .PCIN('0), // 58-bit input: P cascade + // Control inputs: Control Inputs/Status Bits + .ALUMODE(4'h0), // 4-bit input: ALU control + .CARRYINSEL('0), // 3-bit input: Carry select + .CLK(clk), // 1-bit input: Clock + .INMODE({5'b10001}), // 5-bit input: INMODE control + .NEGATE('0), // 3-bit input: Negates the input of the multiplier + .OPMODE({ + Lc[i][j][DSP_PIPELINE_STAGES-1] ? 2'b00 : 2'b01, + 7'b000_0000 + }), // 9-bit input: Operation mode + // Data inputs: Data Ports + .A({ 7'b0, Ac_int[i][j] }), // 34-bit input: A data + .B(Wc_int[i][j]), // 24-bit input: B data + .C('x), // 58-bit input: C data + .CARRYIN('0), // 1-bit input: Carry-in + .D('x), // 27-bit input: D data + // Reset/Clock Enable inputs: Reset/Clock Enable Inputs + .ASYNC_RST('0), // 1-bit input: Asynchronous reset for all registers. + .CEA1(ivld), // 1-bit input: Clock enable for 1st stage AREG + .CEA2('0), // 1-bit input: Clock enable for 2nd stage AREG + .CEAD('0), // 1-bit input: Clock enable for ADREG + .CEALUMODE('0), // 1-bit input: Clock enable for ALUMODE + .CEB1(ivld), // 1-bit input: Clock enable for 1st stage BREG + .CEB2('0), // 1-bit input: Clock enable for 2nd stage BREG + .CEC('0), // 1-bit input: Clock enable for CREG + .CECARRYIN('0), // 1-bit input: Clock enable for CARRYINREG + .CECTRL(ivld), // 1-bit input: Clock enable for OPMODEREG and CARRYINSELREG + .CED('0), // 1-bit input: Clock enable for DREG + .CEINMODE('1), // 1-bit input: Clock enable for INMODEREG + .CEM(ivld), // 1-bit input: Clock enable for MREG + .CEP(ivld), // 1-bit input: Clock enable for PREG + .RSTA(rst), // 1-bit input: Reset for AREG + .RSTALLCARRYIN('0), // 1-bit input: Reset for CARRYINREG + .RSTALUMODE('0), // 1-bit input: Reset for ALUMODEREG + .RSTB(rst), // 1-bit input: Reset for BREG + .RSTC('0), // 1-bit input: Reset for CREG + .RSTCTRL(rst), // 1-bit input: Reset for OPMODEREG and CARRYINSELREG + .RSTD('0), // 1-bit input: Reset for DREG and ADREG + .RSTINMODE(rst), // 1-bit input: Reset for INMODE register + .RSTM(rst), // 1-bit input: Reset for MREG + .RSTP(rst) // 1-bit input: Reset for PREG + ); + + end + end + end + +// Collect +//------------------------------------------------------------------------------------ + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Pc_C = '0, Pc_N; + logic [CLEN-1:0][PE-1:0] Pc_vld_C = '0, Pc_vld_N; + + always_ff @(posedge clk) begin + if(rst) begin + Pc_C <= 'X; + Pc_vld_C <= '0; + end else begin + if(ivld) begin + Pc_C <= Pc_N; + Pc_vld_C <= Pc_vld_N; + end + end + end + + for(genvar i = 0; i < CLEN; i++) begin + for(genvar j = 0; j < PE; j++) begin + if(i == CLEN-1) begin + assign Pc_N[i][j] = pout[i][j]; + assign Pc_vld_N[i][j] = Lc[i][j][DSP_PIPELINE_STAGES]; + end else begin + assign Pc_N[i][j] = Lc[i][j][DSP_PIPELINE_STAGES] ? pout[i][j] : Pc_C[i+1][j]; + assign Pc_vld_N[i][j] = Lc[i][j][DSP_PIPELINE_STAGES] ? 1'b1 : Pc_vld_C[i+1][j]; + end + end + end + /* + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + Pc[i][j] <= '0; + Pc_vld[i][j] <= '0; + end + end + end else begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + if(ivld) begin + if(i == CLEN-1) begin + Pc[i][j] <= pout[i][j]; + Pc_vld[i][j] <= Lc[i][j][DSP_PIPELINE_STAGES]; + end else begin + Pc[i][j] <= Lc[i][j][DSP_PIPELINE_STAGES] ? pout[i][j] : Pc[i+1][j]; + Pc_vld[i][j] <= Lc[i][j][DSP_PIPELINE_STAGES] ? 1'b1 : Pc_vld[i+1][j]; + end + end + end + end + end + end + */ + + logic [PE-1:0] ovld; + logic [PE-1:0][ACCU_WIDTH-1:0] p; + + for(genvar i = 0; i < PE; i++) begin + assign ovld[i] = Pc_vld_C[0][i]; + assign p[i] = Pc_C[0][i]; + end + + collect_out_2d #( + .PE(PE), + .ACCU_WIDTH(ACCU_WIDTH) + ) inst_collect_out ( + .clk(clk), .rst(rst), + .en(en), + .p_tdata(p), .p_tvalid(ovld), + .m_axis_tdata(m_axis_tdata), .m_axis_tvalid(m_axis_tvalid), .m_axis_tready(m_axis_tready) + ); + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/2d/sched_weights_2d.sv b/finn-rtllib/mvu_tiled/mmu/2d/sched_weights_2d.sv new file mode 100644 index 0000000000..93cfad09b5 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/2d/sched_weights_2d.sv @@ -0,0 +1,165 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module sched_weights_2d #( + int unsigned CU_SIMD, + int unsigned PE, + int unsigned WEIGHT_WIDTH, + + int unsigned N_BEATS_OP, + int unsigned N_BEATS_EP, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + // Input Stream + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] s_axis_tdata, + input logic s_axis_tvalid, + output logic s_axis_tready, + + // Output Stream + output logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Params +// --------------------------------------------------------------------- + localparam integer CNT_EPLG_BITS = (N_BEATS_OP > N_BEATS_EP) ? + (N_BEATS_OP == 1) ? 1 : $clog2(N_BEATS_OP) : + (N_BEATS_EP == 1) ? 1 : $clog2(N_BEATS_EP); + +// Queueing +// --------------------------------------------------------------------- + logic [PE-1:0] q_in_tready, q_in_tvalid; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] q_in_tdata; + logic [PE-1:0] q_out_tready, q_out_tvalid; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] q_out_tdata; + + assign s_axis_tready = &q_in_tready; + + for(genvar i = 0; i < PE; i++) begin + assign q_in_tdata[i] = s_axis_tdata[i]; + assign q_in_tvalid[i] = s_axis_tvalid && s_axis_tready; + + Q_srl #( + .depth(2*PE), .width(CU_SIMD*WEIGHT_WIDTH) + ) inst_queue ( + .clock(clk), .reset(rst), + .i_v(q_in_tvalid[i]), .i_r(q_in_tready[i]), .i_d(q_in_tdata[i]), + .o_v(q_out_tvalid[i]), .o_r(q_out_tready[i]), .o_d(q_out_tdata[i]) + ); + end + +// Shifting +// --------------------------------------------------------------------- + logic [PE-1:0] valid_C = '0, valid_N; + logic eplg_C = '0, eplg_N; + logic [CNT_EPLG_BITS-1:0] cnt_eplg_C = '0, cnt_eplg_N; + + logic ovld, ordy; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] odat; + + // REG + always_ff @(posedge clk) begin + if(rst) begin + valid_C[0] <= 1'b1; + valid_C[PE-1:1] <= '0; + eplg_C <= 1'b0; + cnt_eplg_C <= '0; + end else begin + valid_C <= valid_N; + eplg_C <= eplg_N; + cnt_eplg_C <= cnt_eplg_N; + end + end + + // DP + always_comb begin + valid_N = valid_C; + eplg_N = eplg_C; + cnt_eplg_N = cnt_eplg_C; + + // Read + if (ovld && ordy) begin + // Shift ctrl + valid_N[PE-1:1] = valid_C[PE-2:0]; + if(eplg_C) begin + if(cnt_eplg_C == N_BEATS_EP-1) begin + eplg_N = 1'b0; + cnt_eplg_N = 0; + valid_N[0] = 1'b1; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N[0] = 1'b0; + end + end else begin + if(cnt_eplg_C == N_BEATS_OP-1) begin + eplg_N = 1'b1; + cnt_eplg_N = 0; + valid_N[0] = 1'b0; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N[0] = 1'b1; + end + end + end + end + + // Output valid + always_comb begin + ovld = 1'b1; + + for(int i = 0; i < PE; i++) begin + if((valid_C[i] & q_out_tvalid[i]) != valid_C[i]) begin + ovld = 1'b0; + end + end + end + + for(genvar i = 0; i < PE; i++) begin + assign q_out_tready[i] = (ovld && ordy) && valid_C[i]; + assign odat[i] = valid_C[i] ? q_out_tdata[i] : '0; + end + +// Oreg +// --------------------------------------------------------------------- + skid #(.DATA_WIDTH(PE*CU_SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat(odat), .ivld(ovld), .irdy(ordy), + .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/en_global.sv b/finn-rtllib/mvu_tiled/mmu/en_global.sv new file mode 100644 index 0000000000..9547660d73 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/en_global.sv @@ -0,0 +1,99 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module en_global #( + int unsigned PE, + int unsigned CLEN, + int unsigned CU_SIMD = 3, + + int unsigned WEIGHT_WIDTH, + int unsigned ACTIVATION_WIDTH, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + input logic en, + + // Activation Stream + input logic s_act_tvalid, + output logic s_act_tready, + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] s_act_tdata, + input logic [CLEN-1:0] s_act_tlast, + + output logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] m_act_tdata, + output logic [CLEN-1:0] m_act_tlast, + + // Weight Stream + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] s_wgt_tdata, + input logic s_wgt_tvalid, + output logic s_wgt_tready, + + output logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] m_wgt_tdata, + + output logic m_tvalid +); + +// Global enable +// --------------------------------------------------------------------- +logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] act_tdata; +logic [CLEN-1:0] act_tlast; +logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] wgt_tdata; +logic ovld; +logic ordy; + +assign ovld = en && s_act_tvalid && s_wgt_tvalid; +assign s_act_tready = en && s_wgt_tvalid; +assign s_wgt_tready = en && s_act_tvalid; + + +assign act_tdata = ovld ? s_act_tdata : '0; +assign act_tlast = ovld ? s_act_tlast : '0; +assign wgt_tdata = ovld ? s_wgt_tdata : '0; + + +// Output +// --------------------------------------------------------------------- +skid #(.DATA_WIDTH(PE*CU_SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg_weights ( + .clk(clk), .rst(rst), + .idat(wgt_tdata), .ivld(ovld), .irdy(), + .odat(m_wgt_tdata), .ovld(), .ordy(1'b1) +); + +skid #(.DATA_WIDTH(CLEN*CU_SIMD*ACTIVATION_WIDTH+CLEN), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg_activations ( + .clk(clk), .rst(rst), + .idat({act_tlast, act_tdata}), .ivld(ovld), .irdy(ordy), + .odat({m_act_tlast, m_act_tdata}), .ovld(m_tvalid), .ordy(1'b1) +); + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/mmu_axi.sv b/finn-rtllib/mvu_tiled/mmu/mmu_axi.sv new file mode 100644 index 0000000000..77d07d9db2 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/mmu_axi.sv @@ -0,0 +1,269 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module mmu_axi #( + string GEMM_TYPE = "mmau", + int unsigned PE, + int unsigned SIMD, + int unsigned CU_SIMD = 3, + + int unsigned MW, + int unsigned MH, + int unsigned N_VECTORS, + + int unsigned WEIGHT_WIDTH, + int unsigned ACTIVATION_WIDTH, + int unsigned ACCU_WIDTH, + + int unsigned IN_TILED = 0, + int unsigned OUT_TILED = 0, + + int unsigned DSP_STAGES = 3, + bit SIGNED_ACTIVATIONS = 1, + bit PUMPED_COMPUTE = 0, // Not used + bit FORCE_BEHAVIOURAL = 0, + + int unsigned N_DCPL_STAGES = 2, + + // Safely deducible parameters + localparam int unsigned CLEN = (SIMD + CU_SIMD-1)/ CU_SIMD, + localparam int unsigned WSIMD = PE * CU_SIMD, + localparam int unsigned ASIMD = CLEN * CU_SIMD, + + localparam int unsigned WEIGHT_STREAM_WIDTH = WSIMD * WEIGHT_WIDTH, + localparam int unsigned WEIGHT_STREAM_WIDTH_BA = (WEIGHT_STREAM_WIDTH + 7)/8 * 8, + localparam int unsigned INPUT_STREAM_WIDTH = SIMD * ACTIVATION_WIDTH, + localparam int unsigned INPUT_STREAM_WIDTH_BA = (INPUT_STREAM_WIDTH + 7)/8 * 8, + localparam int unsigned OUTPUT_STREAM_WIDTH = PE * ACCU_WIDTH, + localparam int unsigned OUTPUT_STREAM_WIDTH_BA = (OUTPUT_STREAM_WIDTH + 7)/8 * 8 +)( + // Global Control + input logic ap_clk, + input logic ap_clk2x, + input logic ap_rst_n, + + // Weight Stream + input logic [WEIGHT_STREAM_WIDTH_BA-1:0] s_axis_weights_tdata, + input logic s_axis_weights_tvalid, + output logic s_axis_weights_tready, + + // Input Stream + input logic [INPUT_STREAM_WIDTH_BA-1:0] s_axis_input_tdata, + input logic s_axis_input_tvalid, + output logic s_axis_input_tready, + + // Output Stream + output logic [OUTPUT_STREAM_WIDTH_BA-1:0] m_axis_output_tdata, + output logic m_axis_output_tvalid, + input logic m_axis_output_tready +); + +// Checks and params +// --------------------------------------------------------------------- + initial begin + if (SIMD != CLEN * CU_SIMD) begin + $error("%m: SIMD (%0d) should be a multiple of CU_SIMD and CLEN. (TODO: Needs testing)", SIMD); + $finish; + end + if (MW % SIMD != 0) begin + $error("%m: MW (%0d) is not a multiple of SIMD (%0d).", MW, SIMD); + $finish; + end + if (MH % PE != 0) begin + $error("%m: MH (%0d) is not a multiple of PE (%0d).", MH, PE); + $finish; + end + if (WEIGHT_WIDTH > 8) begin + $error("Weight width of %0d-bits exceeds maximum of 8-bits", WEIGHT_WIDTH); + $finish; + end + if (ACTIVATION_WIDTH > 8) begin + $error("Activation width of %0d-bits exceeds maximum of 8-bits", ACTIVATION_WIDTH); + $finish; + end + end + + localparam int unsigned SF = MW / SIMD; + localparam int unsigned NF = MH / PE; + localparam int unsigned N_TRS_OP = SF * NF * N_VECTORS; + localparam int unsigned N_TRS_EP = (GEMM_TYPE == "mmau_1d") ? CLEN-1 + CLEN-1 + DSP_STAGES + 2 : + CLEN-1 + CLEN-1 + DSP_STAGES + PE; + +// Input replay +// --------------------------------------------------------------------- + logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] adat_s0; + logic [ASIMD-1:0][ACTIVATION_WIDTH-1:0] adat_s0_wd; + logic alast_s0; + logic avld_s0, ardy_s0; + + logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] act_s0_tdata; + logic [ASIMD-1:0][ACTIVATION_WIDTH-1:0] act_s0_tdata_mod; + logic act_s0_tlast; + logic act_s0_tvalid, act_s0_tready; + + uwire [2:0] act_done; + input_gen #( + .DATA_WIDTH(SIMD*ACTIVATION_WIDTH), + .FM_SIZE(SF * CLEN), + .D(3), + .DIMS('{NF, SF, CLEN}), + .COEFS(IN_TILED ? '{0, CLEN, 1} : '{0, 1, SF}) + ) activation_replay ( + .clk(ap_clk), .rst(~ap_rst_n), + .idat(s_axis_input_tdata), + .ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), + .odat(act_s0_tdata), .ovld(act_s0_tvalid), .olst(), .odone(act_done), .ordy(act_s0_tready) + ); + assign act_s0_tlast = act_done[1]; + + if (ASIMD > SIMD) + assign act_s0_tdata_mod[ASIMD-1:SIMD] = '0; + assign act_s0_tdata_mod[SIMD-1:0] = act_s0_tdata[SIMD-1:0]; + +// Activation scheduling +// --------------------------------------------------------------------- + logic [ASIMD-1:0][ACTIVATION_WIDTH-1:0] act_s1_tdata; + logic [CLEN-1:0] act_s1_tlast; + logic act_s1_tvalid, act_s1_tready; + + sched_activations #( + .CU_SIMD(CU_SIMD), .CLEN(CLEN), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), + .N_BEATS_OP(N_TRS_OP), .N_BEATS_EP(N_TRS_EP) + ) inst_sched_act ( + .clk(ap_clk), .rst(~ap_rst_n), + .s_axis_tdata(act_s0_tdata_mod), .s_axis_tvalid(act_s0_tvalid), .s_axis_tready(act_s0_tready), .s_axis_tlast(act_s0_tlast), + .m_axis_tdata(act_s1_tdata), .m_axis_tvalid(act_s1_tvalid), .m_axis_tready(act_s1_tready), .m_axis_tlast(act_s1_tlast) + ); + +// Weight scheduling +// --------------------------------------------------------------------- + logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] wgt_s1_tdata; + logic wgt_s1_tvalid, wgt_s1_tready; + +if(GEMM_TYPE == "mmau_1d") begin + sched_weights_1d #( + .CU_SIMD(CU_SIMD), .PE(PE), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .N_BEATS_OP(N_TRS_OP), .N_BEATS_EP(N_TRS_EP) + ) inst_sched_wgt ( + .clk(ap_clk), .rst(~ap_rst_n), + .s_axis_tdata(s_axis_weights_tdata), .s_axis_tvalid(s_axis_weights_tvalid), .s_axis_tready(s_axis_weights_tready), + .m_axis_tdata(wgt_s1_tdata), .m_axis_tvalid(wgt_s1_tvalid), .m_axis_tready(wgt_s1_tready) + ); +end else begin + sched_weights_2d #( + .CU_SIMD(CU_SIMD), .PE(PE), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .N_BEATS_OP(N_TRS_OP), .N_BEATS_EP(N_TRS_EP) + ) inst_sched_wgt ( + .clk(ap_clk), .rst(~ap_rst_n), + .s_axis_tdata(s_axis_weights_tdata), .s_axis_tvalid(s_axis_weights_tvalid), .s_axis_tready(s_axis_weights_tready), + .m_axis_tdata(wgt_s1_tdata), .m_axis_tvalid(wgt_s1_tvalid), .m_axis_tready(wgt_s1_tready) + ); +end + + +// Global enable +// --------------------------------------------------------------------- + logic en; + logic [ASIMD-1:0][ACTIVATION_WIDTH-1:0] act_s2_tdata; + logic [CLEN-1:0] act_s2_tlast; + logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] wgt_s2_tdata; + logic s2_tvalid; + + en_global #( + .PE(PE), .CLEN(CLEN), .CU_SIMD(CU_SIMD), + .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACTIVATION_WIDTH(ACTIVATION_WIDTH) + ) inst_en_global ( + .clk(ap_clk), .rst(~ap_rst_n), + .en(en), + .s_act_tvalid(act_s1_tvalid), .s_act_tready(act_s1_tready), .s_act_tdata(act_s1_tdata), .s_act_tlast(act_s1_tlast), + .m_act_tdata(act_s2_tdata), .m_act_tlast(act_s2_tlast), + .s_wgt_tvalid(wgt_s1_tvalid), .s_wgt_tready(wgt_s1_tready), .s_wgt_tdata(wgt_s1_tdata), + .m_wgt_tdata(wgt_s2_tdata), + .m_tvalid(s2_tvalid) + ); + +// CU +// --------------------------------------------------------------------- + logic p_tvalid, p_tready; + logic [PE-1:0][ACCU_WIDTH-1:0] p_tdata; + +if(GEMM_TYPE == "mmau_1d") begin + cu_mmau_1d #( + .PE(PE), .CLEN(CLEN), .CU_SIMD(CU_SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .FORCE_BEHAVIOURAL(FORCE_BEHAVIOURAL) + ) inst_cu_mmau ( + .clk(ap_clk), .rst(~ap_rst_n), + .en(en), + .ivld(s2_tvalid), .a(act_s2_tdata), .ilast(act_s2_tlast), .w(wgt_s2_tdata), + .m_axis_tvalid(p_tvalid), .m_axis_tready(p_tready), .m_axis_tdata(p_tdata) + ); +end else begin + cu_mmau_2d #( + .PE(PE), .CLEN(CLEN), .CU_SIMD(CU_SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .FORCE_BEHAVIOURAL(FORCE_BEHAVIOURAL) + ) inst_cu_mmau ( + .clk(ap_clk), .rst(~ap_rst_n), + .en(en), + .ivld(s2_tvalid), .a(act_s2_tdata), .ilast(act_s2_tlast), .w(wgt_s2_tdata), + .m_axis_tvalid(p_tvalid), .m_axis_tready(p_tready), .m_axis_tdata(p_tdata) + ); +end + + +// Reorder +// --------------------------------------------------------------------- + if(OUT_TILED == 0) begin + input_gen #( + .DATA_WIDTH(OUTPUT_STREAM_WIDTH_BA), + .FM_SIZE(NF * CLEN), + .D(2), + .DIMS('{CLEN, NF}), + .COEFS('{1, CLEN}) + ) inst_reorder_out ( + .clk(ap_clk), .rst(~ap_rst_n), + .idat(p_tdata), + .ivld(p_tvalid), .irdy(p_tready), + .odat(m_axis_output_tdata), .ovld(m_axis_output_tvalid), + .olst(), .odone(), .ordy(m_axis_output_tready) + ); + end else begin + assign m_axis_output_tvalid = p_tvalid; + assign p_tready = m_axis_output_tready; + assign m_axis_output_tdata = p_tdata; + end + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/mmu_axi_wrapper.v b/finn-rtllib/mvu_tiled/mmu/mmu_axi_wrapper.v new file mode 100644 index 0000000000..70c75193ba --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/mmu_axi_wrapper.v @@ -0,0 +1,100 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module $MODULE_NAME_AXI_WRAPPER$ #( + parameter PE = $PE$, + parameter SIMD = $SIMD$, + parameter CU_SIMD = 3, + parameter ACTIVATION_WIDTH = $ACTIVATION_WIDTH$, + parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$, + parameter ACCU_WIDTH = $ACCU_WIDTH$, + parameter MW = $MW$, + parameter MH = $MH$, + parameter N_VECTORS = $N_VECTORS$, + parameter SIGNED_ACTIVATIONS = $SIGNED_ACTIVATIONS$, + parameter PUMPED_COMPUTE = $PUMPED_COMPUTE$, + + // Safely deducible parameters + parameter WSIMD = PE * CU_SIMD, + parameter WEIGHT_STREAM_WIDTH_BA = (WSIMD * WEIGHT_WIDTH + 7)/8 * 8, + parameter INPUT_STREAM_WIDTH_BA = (SIMD * ACTIVATION_WIDTH + 7) / 8 * 8, + parameter OUTPUT_STREAM_WIDTH_BA = (PE * ACCU_WIDTH + 7)/8 * 8 +)( + // Global Control + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in1_V:in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk2x CLK" *) + input ap_clk2x, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + // Weight Stream + input [WEIGHT_STREAM_WIDTH_BA-1:0] in1_V_TDATA, + input in1_V_TVALID, + output in1_V_TREADY, + // Input Stream + input [INPUT_STREAM_WIDTH_BA-1:0] in0_V_TDATA, + input in0_V_TVALID, + output in0_V_TREADY, + // Output Stream + output [OUTPUT_STREAM_WIDTH_BA-1:0] out0_V_TDATA, + output out0_V_TVALID, + input out0_V_TREADY +); + +// NOTE: MW and MH are swapped -- FINN convention (MW=input features, MH=output features) +// is opposite to the MMU RTL convention. +mmu_axi #( + .GEMM_TYPE("$GEMM_TYPE$"), + .PE(PE), .SIMD(SIMD), .CU_SIMD(CU_SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .MW(MW), .MH(MH), .N_VECTORS(N_VECTORS), + .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .PUMPED_COMPUTE(PUMPED_COMPUTE), + .FORCE_BEHAVIOURAL(0) +) inst ( + .ap_clk(ap_clk), + .ap_clk2x(ap_clk2x), + .ap_rst_n(ap_rst_n), + .s_axis_weights_tdata(in1_V_TDATA), + .s_axis_weights_tvalid(in1_V_TVALID), + .s_axis_weights_tready(in1_V_TREADY), + .s_axis_input_tdata(in0_V_TDATA), + .s_axis_input_tvalid(in0_V_TVALID), + .s_axis_input_tready(in0_V_TREADY), + .m_axis_output_tdata(out0_V_TDATA), + .m_axis_output_tvalid(out0_V_TVALID), + .m_axis_output_tready(out0_V_TREADY) +); + +endmodule // $MODULE_NAME_AXI_WRAPPER$ diff --git a/finn-rtllib/mvu_tiled/mmu/q_writer.sv b/finn-rtllib/mvu_tiled/mmu/q_writer.sv new file mode 100644 index 0000000000..2a796c0f8d --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/q_writer.sv @@ -0,0 +1,124 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module q_writer #( + int unsigned CU_SIMD, + int unsigned CLEN, + int unsigned ACTIVATION_WIDTH +)( + // Global Control + input logic clk, + input logic rst, + + // Input Stream + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] s_axis_tdata, + input logic s_axis_tlast, + input logic s_axis_tvalid, + output logic s_axis_tready, + + // Output Stream + output logic [CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tlast, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Params +// --------------------------------------------------------------------- + localparam integer CLEN_BITS = (CLEN == 1) ? 1 : $clog2(CLEN); + +// Skid +// --------------------------------------------------------------------- + logic axis_s0_tvalid, axis_s0_tready; + logic axis_s0_tlast; + logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] axis_s0_tdata; + + skid #(.DATA_WIDTH(CLEN*CU_SIMD*ACTIVATION_WIDTH + 1), .FEED_STAGES(1)) inst_reg ( + .clk(clk), .rst(rst), + .idat({s_axis_tlast, s_axis_tdata}), .ivld(s_axis_tvalid), .irdy(s_axis_tready), + .odat({axis_s0_tlast, axis_s0_tdata}), .ovld(axis_s0_tvalid), .ordy(axis_s0_tready) + ); + +// PtoS +// --------------------------------------------------------------------- + logic [CLEN_BITS-1:0] wr_ptr_C = '0, wr_ptr_N; + + logic axis_s1_tvalid, axis_s1_tready; + logic axis_s1_tlast; + logic [CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] axis_s1_tdata; + + // REG + always_ff @(posedge clk) begin + if(rst) begin + wr_ptr_C <= 0; + end else begin + wr_ptr_C <= wr_ptr_N; + end + end + + // DP + always_comb begin + wr_ptr_N = wr_ptr_C; + + axis_s0_tready = 1'b0; + axis_s1_tvalid = 1'b0; + axis_s1_tdata = axis_s0_tdata[wr_ptr_C]; + axis_s1_tlast = 1'b0; + + if(axis_s0_tvalid) begin + axis_s1_tvalid = 1'b1; + + if(axis_s1_tready) begin + if(wr_ptr_C == CLEN-1) begin + wr_ptr_N = 0; + axis_s0_tready = 1'b1; + axis_s1_tlast = axis_s0_tlast; + end else begin + wr_ptr_N = wr_ptr_C + 1; + end + end + end + end + +// Queue +// --------------------------------------------------------------------- + Q_srl #( + .depth(CLEN), + .width(CU_SIMD*ACTIVATION_WIDTH+1) + ) inst_queue ( + .clock(clk), .reset(rst), + .count(), .maxcount(), + .i_v(axis_s1_tvalid), .i_r(axis_s1_tready), .i_d({axis_s1_tlast, axis_s1_tdata}), + .o_v(m_axis_tvalid), .o_r(m_axis_tready), .o_d({m_axis_tlast, m_axis_tdata}) + ); + +endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/sched_activations.sv b/finn-rtllib/mvu_tiled/mmu/sched_activations.sv new file mode 100644 index 0000000000..29f832fcc1 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mmu/sched_activations.sv @@ -0,0 +1,199 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module sched_activations #( + int unsigned CU_SIMD, + int unsigned CLEN, + int unsigned ACTIVATION_WIDTH, + + int unsigned N_BEATS_OP, + int unsigned N_BEATS_EP, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + // Input Stream + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] s_axis_tdata, + input logic s_axis_tlast, + input logic s_axis_tvalid, + output logic s_axis_tready, + + // Output Stream + output logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] m_axis_tdata, + output logic [CLEN-1:0] m_axis_tlast, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Params +// --------------------------------------------------------------------- + localparam integer CLEN_BITS = (CLEN == 1) ? 1 : $clog2(CLEN); + localparam integer CNT_EPLG_BITS = (N_BEATS_OP > N_BEATS_EP) ? + (N_BEATS_OP == 1) ? 1 : $clog2(N_BEATS_OP) : + (N_BEATS_EP == 1) ? 1 : $clog2(N_BEATS_EP); + +// Shifting +// --------------------------------------------------------------------- + logic [CLEN_BITS-1:0] wr_ptr_C = '0, wr_ptr_N; + + logic [CLEN-1:0] valid_C = '0, valid_N; + logic eplg_C = '0, eplg_N; + logic [CNT_EPLG_BITS-1:0] cnt_eplg_C = '0, cnt_eplg_N; + + logic [CLEN-1:0] q_in_tvalid, q_in_tready; + logic [CLEN-1:0] q_in_tlast; + logic [CLEN-1:0][CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] q_in_tdata; + + logic [CLEN-1:0] q_out_tvalid, q_out_tready; + logic [CLEN-1:0] q_out_tlast; + logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] q_out_tdata; + + logic ovld, ordy; + logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] odat; + logic [CLEN-1:0] olast; + + + for(genvar i = 0; i < CLEN; i++) begin + q_writer #( + .CLEN(CLEN), .CU_SIMD(CU_SIMD), .ACTIVATION_WIDTH(ACTIVATION_WIDTH) + ) inst_queue_writer ( + .clk(clk), .rst(rst), + .s_axis_tvalid(q_in_tvalid[i]), .s_axis_tready(q_in_tready[i]), .s_axis_tdata(q_in_tdata[i]), .s_axis_tlast(q_in_tlast[i]), + .m_axis_tvalid(q_out_tvalid[i]), .m_axis_tready(q_out_tready[i]), .m_axis_tdata(q_out_tdata[i]), .m_axis_tlast(q_out_tlast[i]) + ); + end + + // REG + always_ff @(posedge clk) begin + if(rst) begin + wr_ptr_C <= '0; + valid_C[0] <= 1'b1; + valid_C[CLEN-1:1] <= '0; + eplg_C <= '0; + cnt_eplg_C <= '0; + end else begin + wr_ptr_C <= wr_ptr_N; + valid_C <= valid_N; + eplg_C <= eplg_N; + cnt_eplg_C <= cnt_eplg_N; + end + end + + // DP + always_comb begin + // Read + valid_N = valid_C; + eplg_N = eplg_C; + cnt_eplg_N = cnt_eplg_C; + + q_out_tready = '0; + + // Read + if (ovld && ordy) begin + // Read from queue + for(int i = 0; i < CLEN; i++) begin + q_out_tready[i] = valid_C[i]; + end + + // Shift ctrl + valid_N[CLEN-1:1] = valid_C[CLEN-2:0]; + if(eplg_C) begin + if(cnt_eplg_C == N_BEATS_EP-1) begin + eplg_N = 1'b0; + cnt_eplg_N = 0; + valid_N[0] = 1'b1; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N[0] = 1'b0; + end + end else begin + if(cnt_eplg_C == N_BEATS_OP-1) begin + eplg_N = 1'b1; + cnt_eplg_N = 0; + valid_N[0] = 1'b0; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N[0] = 1'b1; + end + end + end + + // Write + wr_ptr_N = wr_ptr_C; + + s_axis_tready = 1'b0; + q_in_tvalid = '0; + for(int i = 0; i < CLEN; i++) begin + q_in_tdata[i] = s_axis_tdata; + q_in_tlast[i] = s_axis_tlast; + end + + if(s_axis_tvalid) begin + q_in_tvalid[wr_ptr_C] = 1'b1; + + if(q_in_tready[wr_ptr_C]) begin + s_axis_tready = 1'b1; + wr_ptr_N = (wr_ptr_C == CLEN-1) ? 0 : wr_ptr_C + 1; + end + end + + end + + // Output valid + always_comb begin + ovld = 1'b1; + + for(int i = 0; i < CLEN; i++) begin + if((valid_C[i] & q_out_tvalid[i]) != valid_C[i]) begin + ovld = 1'b0; + end + end + end + + for(genvar i = 0; i < CLEN; i++) begin + assign odat[i] = valid_C[i] ? q_out_tdata[i] : '0; + assign olast[i] = valid_C[i] ? q_out_tlast[i] : '0; + end + +// Oreg +// --------------------------------------------------------------------- + skid #(.DATA_WIDTH(CLEN*(CU_SIMD*ACTIVATION_WIDTH + 1)), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat({olast, odat}), .ivld(ovld), .irdy(ordy), + .odat({m_axis_tlast, m_axis_tdata}), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + + +endmodule diff --git a/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv b/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv new file mode 100644 index 0000000000..558abc4a4a --- /dev/null +++ b/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv @@ -0,0 +1,290 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * @brief Matrix Vector Unit with Tiling (MVU-Tiled) AXI-Stream wrapper. + * @details + * The following compute cores are supported: + * - [4,9]-bit MVU on DSP58 achieving 3 MACs/DSP, + * Folding hints: + * - PE scaling should divide MH. + * - SIMD scaling should divide MW. + * - TH scaling should divide MH_OUTER + * - WSIMD * TH <= PE * SIMD + * - Otherwise, keep SIMD and PE somewhat balanced. SIMD scaling tends to + * impact critical paths more than PE scaling. PE scaling implies a + * bigger fanout on the input activations. + * - Full unfolding along MH (PE=MH) results in no replay buffer instantiated + *****************************************************************************/ + +module mvu_tiled_axi #( + int unsigned PE, + int unsigned SIMD, + + int unsigned WEIGHT_WIDTH, + int unsigned ACTIVATION_WIDTH, + int unsigned ACCU_WIDTH, + + int unsigned MW, + int unsigned MH, + int unsigned TH, + + int unsigned IN_TILED = 0, + int unsigned OUT_TILED = 0, + + bit NARROW_WEIGHTS = 0, // unused — reserved for future narrow-weight support + bit SIGNED_ACTIVATIONS = 0, + bit PUMPED_COMPUTE = 0, // Not meaningful for SIMD < 2, which will error out. + bit FORCE_BEHAVIORAL = 0, // unused — reserved for future behavioral fallback + bit M_REG_LUT = 1, // unused — reserved for future LUT-based M register + + parameter COMPUTE_CORE = "mvu_vvu_8sx9_dsp58", + int unsigned N_DCPL_STAGES = 2, + + // Safely deducible parameters + localparam int unsigned WSIMD = (PE * SIMD) / TH, + localparam int unsigned WEIGHT_STREAM_WIDTH = WSIMD * WEIGHT_WIDTH, + localparam int unsigned WEIGHT_STREAM_WIDTH_BA = (WEIGHT_STREAM_WIDTH + 7)/8 * 8, + localparam int unsigned INPUT_STREAM_WIDTH = SIMD * ACTIVATION_WIDTH, + localparam int unsigned INPUT_STREAM_WIDTH_BA = (INPUT_STREAM_WIDTH + 7)/8 * 8, + localparam int unsigned OUTPUT_STREAM_WIDTH = PE * ACCU_WIDTH, + localparam int unsigned OUTPUT_STREAM_WIDTH_BA = (OUTPUT_STREAM_WIDTH + 7)/8 * 8, + localparam bit SIMD_UNEVEN = SIMD % 2 +)( + // Global Control + input logic ap_clk, + input logic ap_clk2x, // synchronous, double-speed clock; only used for PUMPED_COMPUTE + input logic ap_rst_n, + + // Weight Stream + input logic [WEIGHT_STREAM_WIDTH_BA-1:0] s_axis_weights_tdata, + input logic s_axis_weights_tvalid, + output logic s_axis_weights_tready, + + // Input Stream + input logic [INPUT_STREAM_WIDTH_BA-1:0] s_axis_input_tdata, + input logic s_axis_input_tvalid, + output logic s_axis_input_tready, + + // Output Stream + output logic [OUTPUT_STREAM_WIDTH_BA-1:0] m_axis_output_tdata, + output logic m_axis_output_tvalid, + input logic m_axis_output_tready +); + + //=== Parameter Validation ============================================== + initial begin + if(MW % SIMD != 0) begin + $error("%m: Matrix width (%0d) is not a multiple of SIMD (%0d).", MW, SIMD); + $finish; + end + if(MH % PE != 0) begin + $error("%m: Matrix height (%0d) is not a multiple of PE (%0d).", MH, PE); + $finish; + end + if((PE * SIMD) % TH != 0) begin + $error("%m: Tile (%0d) is not a multiple of TH (%0d).", (PE*SIMD), TH); + $finish; + end + if(PUMPED_COMPUTE && (SIMD == 1)) begin + $error("Clock pumping an input of SIMD=1 is not meaningful."); + $finish; + end + if(WEIGHT_WIDTH > 8) begin + $error("Weight width of %0d-bits exceeds maximum of 8-bits", WEIGHT_WIDTH); + $finish; + end + if(ACTIVATION_WIDTH > 8) begin + $error("Activation width of %0d-bits exceeds maximum of 8-bits", ACTIVATION_WIDTH); + $finish; + end + end + + uwire rst = !ap_rst_n; + + //=== Activation Replay ================================================= + typedef logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] mvu_flatin_t; + uwire mvu_flatin_t amvau; + uwire alast; + uwire avld; + uwire ardy; + + localparam int unsigned SF = MW / SIMD; + localparam int unsigned NF = MH / PE; + + uwire [2:0] act_done; + input_gen #( + .DATA_WIDTH($bits(mvu_flatin_t)), + .FM_SIZE(SF * TH), + .D(3), + .DIMS('{NF, SF, TH}), + .COEFS('{0, 1, SF}) + ) activation_replay ( + .clk(ap_clk), .rst(rst), + .idat(mvu_flatin_t'(s_axis_input_tdata)), + .ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), + .odat(amvau), .ovld(avld), .olst(), .odone(act_done), .ordy(ardy) + ); + assign alast = act_done[1]; + + //=== Weight Buffering ================================================== + typedef logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] mvu_w_t; + uwire mvu_w_t wdat; + uwire wvld; + uwire wrdy; + + weights_buff_tile #( + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .SIMD(SIMD), .PE(PE), + .TH(TH), .WSIMD(WSIMD), + .N_DCPL_STAGES(N_DCPL_STAGES) + ) inst_weights_buff_tile ( + .clk(ap_clk), .rst(rst), + .ivld(s_axis_weights_tvalid), .irdy(s_axis_weights_tready), .idat(s_axis_weights_tdata), + .ovld(wvld), .ordy(wrdy), .odat(wdat) + ); + + //=== Flow Control ====================================================== + uwire en; + uwire istb = avld && wvld; + assign ardy = en && wvld; + assign wrdy = en && avld; + + //=== DSP Compute ======================================================= + typedef logic [PE-1:0][ACCU_WIDTH-1:0] dsp_p_t; + uwire ovld; + uwire dsp_p_t odat; + if(1) begin : blkDsp + localparam int unsigned EFFECTIVE_SIMD = SIMD_UNEVEN && PUMPED_COMPUTE? SIMD+1 : SIMD; + localparam int unsigned DSP_SIMD = EFFECTIVE_SIMD / (PUMPED_COMPUTE+1); + typedef logic [PE -1:0][DSP_SIMD-1:0][WEIGHT_WIDTH -1:0] dsp_w_t; + typedef logic [DSP_SIMD-1:0][ACTIVATION_WIDTH-1:0] dsp_a_t; + + uwire dsp_last; + uwire dsp_w_t dsp_w; + uwire dsp_a_t dsp_a; + + uwire dsp_vld; + uwire dsp_p_t dsp_p; + + // TODO: No double-pumping in the initial implementation + uwire dsp_en = en; + + assign dsp_last = alast && istb; + assign dsp_w = wdat; + assign dsp_a = amvau; + + assign ovld = dsp_vld; + assign odat = dsp_p; + + case(COMPUTE_CORE) + "mvu_vvu_8sx9_dsp58": begin : core + cu_mvau_tiled #( + .PE(PE), .SIMD(SIMD), + .TH(TH), + .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS) + ) inst_cu_mvau_tiled ( + .clk(ap_clk), .rst(rst), .en(dsp_en), + .ivld(istb), .ilast(dsp_last), .w(dsp_w), .a(dsp_a), + .ovld(dsp_vld), .p(dsp_p) + ); + end + default: initial begin + $error("Unrecognized COMPUTE_CORE '%s'", COMPUTE_CORE); + $finish; + end + endcase + + end : blkDsp + + //=== Output Register Slice ============================================= + // Make `en` computation independent from external inputs. + // Drive all outputs from registers. + + logic MIntVld; + uwire m_int_rdy; + logic [OUTPUT_STREAM_WIDTH_BA-1:0] MIntDat; + + struct packed { + logic rdy; + logic [PE-1:0][ACCU_WIDTH-1:0] dat; + } A = '{ rdy: 1, default: 'x }; // side-step register used when encountering backpressure + struct packed { + logic vld; + logic [PE-1:0][ACCU_WIDTH-1:0] dat; + } B = '{ vld: 0, default: 'x }; // ultimate output register + + assign en = A.rdy; + uwire b_load = !B.vld || m_int_rdy; + + always_ff @(posedge ap_clk) begin + if(rst) begin + A <= '{ rdy: 1, default: 'x }; + B <= '{ vld: 0, default: 'x }; + end + else begin + if(A.rdy) A.dat <= odat; + A.rdy <= (A.rdy && !ovld) || b_load; + + if(b_load) begin + B <= '{ + vld: ovld || !A.rdy, + dat: A.rdy? odat : A.dat + }; + end + end + end + assign MIntVld = B.vld; + assign MIntDat = { {(OUTPUT_STREAM_WIDTH_BA-OUTPUT_STREAM_WIDTH){B.dat[PE-1][ACCU_WIDTH-1]}}, B.dat }; + + //=== Output Reordering ================================================= + + if(OUT_TILED == 0) begin : genReorder + input_gen #( + .DATA_WIDTH(OUTPUT_STREAM_WIDTH_BA), + .FM_SIZE(NF * TH), + .D(2), + .DIMS('{TH, NF}), + .COEFS('{1, TH}) + ) inst_reorder_out ( + .clk(ap_clk), .rst(rst), + .idat(MIntDat), + .ivld(MIntVld), .irdy(m_int_rdy), + .odat(m_axis_output_tdata), .ovld(m_axis_output_tvalid), + .olst(), .odone(), .ordy(m_axis_output_tready) + ); + end : genReorder + else begin : genPassthru + assign m_axis_output_tvalid = MIntVld; + assign m_int_rdy = m_axis_output_tready; + assign m_axis_output_tdata = MIntDat; + end : genPassthru + +endmodule : mvu_tiled_axi diff --git a/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v b/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v new file mode 100644 index 0000000000..36ce60c3b6 --- /dev/null +++ b/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v @@ -0,0 +1,97 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module $MODULE_NAME_AXI_WRAPPER$ #( + parameter PE = $PE$, + parameter SIMD = $SIMD$, + parameter ACTIVATION_WIDTH = $ACTIVATION_WIDTH$, + parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$, + parameter ACCU_WIDTH = $ACCU_WIDTH$, + parameter MW = $MW$, + parameter MH = $MH$, + parameter TH = $TH$, + parameter NARROW_WEIGHTS = $NARROW_WEIGHTS$, + parameter SIGNED_ACTIVATIONS = $SIGNED_ACTIVATIONS$, + parameter PUMPED_COMPUTE = $PUMPED_COMPUTE$, + + // Safely deducible parameters + parameter WSIMD = (PE * SIMD) / TH, + parameter WEIGHT_STREAM_WIDTH_BA = (WSIMD * WEIGHT_WIDTH + 7)/8 * 8, + parameter INPUT_STREAM_WIDTH_BA = (SIMD * ACTIVATION_WIDTH + 7) / 8 * 8, + parameter OUTPUT_STREAM_WIDTH_BA = (PE * ACCU_WIDTH + 7)/8 * 8 +)( + // Global Control + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in1_V:in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk2x CLK" *) + input ap_clk2x, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + // Weight Stream + input [WEIGHT_STREAM_WIDTH_BA-1:0] in1_V_TDATA, + input in1_V_TVALID, + output in1_V_TREADY, + // Input Stream + input [INPUT_STREAM_WIDTH_BA-1:0] in0_V_TDATA, + input in0_V_TVALID, + output in0_V_TREADY, + // Output Stream + output [OUTPUT_STREAM_WIDTH_BA-1:0] out0_V_TDATA, + output out0_V_TVALID, + input out0_V_TREADY +); + +mvu_tiled_axi #( + .PE(PE), .SIMD(SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .MW(MW), .MH(MH), .TH(TH), + .NARROW_WEIGHTS(NARROW_WEIGHTS), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .PUMPED_COMPUTE(PUMPED_COMPUTE), + .FORCE_BEHAVIORAL(0) + ) inst ( + .ap_clk(ap_clk), + .ap_clk2x(ap_clk2x), + .ap_rst_n(ap_rst_n), + .s_axis_weights_tdata(in1_V_TDATA), + .s_axis_weights_tvalid(in1_V_TVALID), + .s_axis_weights_tready(in1_V_TREADY), + .s_axis_input_tdata(in0_V_TDATA), + .s_axis_input_tvalid(in0_V_TVALID), + .s_axis_input_tready(in0_V_TREADY), + .m_axis_output_tdata(out0_V_TDATA), + .m_axis_output_tvalid(out0_V_TVALID), + .m_axis_output_tready(out0_V_TREADY) +); + +endmodule // $MODULE_NAME_AXI_WRAPPER$ diff --git a/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv b/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv new file mode 100644 index 0000000000..a45fa7db50 --- /dev/null +++ b/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv @@ -0,0 +1,334 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @brief Testbench for MVU-Tiled AXI wrapper module. + * @details + * Adapted from mvu/tb/mvu_axi_tb.sv for the tiled architecture. + * Exercises mvu_tiled_axi with multiple parameter configurations in parallel. + * + * Data flow under test: + * activations -> input_gen (replays MH/PE times) + * weights -> weights_buff_tile (collects NW=TH words, replays TH times) + * compute -> cu_mvau_tiled (DSP58 INT8, 3 MACs/DSP) + * accumulate -> acc_stage (pipelined add_multi tree + circular FIFO) + * reorder -> input_gen (transpose tiled -> sequential NF order) + * + * Weight feed order: + * For each neuron fold (h), for each SIMD fold (w): + * send TH chunks of WSIMD weights (PE*SIMD total per tile). + * + * Activation feed order: + * For each TH tile (y), for each SIMD fold (x): + * send one SIMD-wide activation word. + * Replay buffer handles repetition across neuron folds. + * + * Output order (after reorder_out): + * Sequential neuron folds: nf=0..MH/PE-1, one PE-wide word per fold, + * repeated for each input vector. + *****************************************************************************/ + +module mvu_tiled_axi_tb; + + // Test Configurations + localparam int unsigned ROUNDS = 7; + + typedef struct { + int unsigned mh; + int unsigned mw; + int unsigned pe; + int unsigned simd; + int unsigned th; + int unsigned weight_width; + int unsigned activation_width; + int unsigned accu_width; + bit signed_activations; + bit narrow_weights; + } cfg_t; + + // Constraints enforced by mvu_tiled_axi: + // - MW % SIMD == 0 + // - MH % PE == 0 + // - (PE * SIMD) % TH == 0 + // - WEIGHT_WIDTH <= 8 + // - ACTIVATION_WIDTH <= 8 (9 for signed -- uses full 9-bit A port) + // - TH >= 2 (TH=1 uses the non-tiled path) + // + // Test selection rationale: + // 0: Baseline -- balanced PE/SIMD, TH=2, signed activations, narrow weights + // 1: Larger TH (=3), odd SIMD (=3) -> CHAINLEN=1 (3 lanes in one DSP) + // 2: TH = PE*SIMD (maximum tiling, WSIMD=1) -- edge case: 1 weight/cycle + // 3: High PE (=6), low SIMD (=2) -> wide PE fanout, CHAINLEN=1 + // 4: PE=MH (no replay, single neuron fold) -- tests replay bypass + // 5: Large matrix, moderate tiling -- closer to real workload + // 6: SIMD=6 (CHAINLEN=2), TH=2 -- multi-DSP chain with tiling + // 7: Unsigned activations, small bitwidths -- corner case for sign extension + // 8: TH=6 (high tiling), PE=2, SIMD=3 -- stress accumulator depth + localparam int unsigned TEST_COUNT = 9; + // mh mw pe simd th ww aw accw sa nw + localparam cfg_t TESTS[TEST_COUNT] = '{ + '{ 12, 12, 6, 3, 2, 8, 8, 24, 1, 1 }, + '{ 12, 12, 6, 3, 3, 4, 4, 16, 1, 0 }, + '{ 12, 8, 2, 4, 8, 8, 8, 24, 1, 0 }, + '{ 12, 10, 6, 2, 3, 4, 8, 20, 0, 0 }, + '{ 4, 12, 4, 3, 2, 8, 4, 20, 1, 1 }, + '{ 24, 18, 6, 6, 3, 4, 4, 18, 1, 0 }, + '{ 16, 12, 4, 6, 2, 8, 8, 24, 0, 1 }, + '{ 8, 12, 4, 3, 2, 2, 2, 12, 0, 1 }, + '{ 6, 9, 2, 3, 6, 4, 4, 16, 1, 0 } + }; + + //=== Global Control ==================================================== + logic clk = 0; + always #5ns clk = !clk; + logic clk2x = 0; + always #2.5ns clk2x = !clk2x; + + logic rst = 1; + initial begin + repeat(16) @(posedge clk); + rst <= 0; + // Allow 100ns DSP startup recovery before any input + #100ns; + end + + bit [TEST_COUNT-1:0] done = '0; + always_comb begin + if(&done) $finish; + end + + //=== Parallel Test Instantiation ======================================= + for(genvar t = 0; t < TEST_COUNT; t++) begin : genTests + localparam cfg_t CFG = TESTS[t]; + localparam int unsigned MH = CFG.mh; + localparam int unsigned MW = CFG.mw; + localparam int unsigned PE = CFG.pe; + localparam int unsigned SIMD = CFG.simd; + localparam int unsigned TH = CFG.th; + localparam int unsigned WEIGHT_WIDTH = CFG.weight_width; + localparam int unsigned ACTIVATION_WIDTH = CFG.activation_width; + localparam int unsigned ACCU_WIDTH = CFG.accu_width; + + // Derived + localparam int unsigned SF = MW / SIMD; // SIMD folds + localparam int unsigned NF = MH / PE; // neuron folds + localparam int unsigned WSIMD = (PE * SIMD) / TH; + + typedef logic signed [WEIGHT_WIDTH -1:0] weight_t; + typedef logic [ACTIVATION_WIDTH-1:0] activation_t; + typedef logic signed [ACCU_WIDTH -1:0] accu_t; + + // Stream widths (matching mvu_tiled_axi localparams) + localparam int unsigned WEIGHT_STREAM_WIDTH = WSIMD * WEIGHT_WIDTH; + localparam int unsigned WEIGHT_STREAM_WIDTH_BA = (WEIGHT_STREAM_WIDTH + 7)/8 * 8; + localparam int unsigned INPUT_STREAM_WIDTH = SIMD * ACTIVATION_WIDTH; + localparam int unsigned INPUT_STREAM_WIDTH_BA = (INPUT_STREAM_WIDTH + 7)/8 * 8; + localparam int unsigned OUTPUT_STREAM_WIDTH = PE * ACCU_WIDTH; + localparam int unsigned OUTPUT_STREAM_WIDTH_BA = (OUTPUT_STREAM_WIDTH + 7)/8 * 8; + + // DUT signals + logic [WEIGHT_STREAM_WIDTH_BA-1:0] wdat; + logic wvld; + uwire wrdy; + logic [INPUT_STREAM_WIDTH_BA-1:0] idat; + logic ivld; + uwire irdy; + uwire [OUTPUT_STREAM_WIDTH_BA-1:0] odat; + uwire ovld; + logic ordy; + + mvu_tiled_axi #( + .PE(PE), .SIMD(SIMD), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), + .ACCU_WIDTH(ACCU_WIDTH), + .MW(MW), .MH(MH), .TH(TH), + .SIGNED_ACTIVATIONS(CFG.signed_activations), + .NARROW_WEIGHTS(CFG.narrow_weights), + .PUMPED_COMPUTE(0), + .FORCE_BEHAVIORAL(0) + ) dut ( + .ap_clk(clk), + .ap_clk2x(clk2x), + .ap_rst_n(!rst), + .s_axis_weights_tdata(wdat), + .s_axis_weights_tvalid(wvld), + .s_axis_weights_tready(wrdy), + .s_axis_input_tdata(idat), + .s_axis_input_tvalid(ivld), + .s_axis_input_tready(irdy), + .m_axis_output_tdata(odat), + .m_axis_output_tvalid(ovld), + .m_axis_output_tready(ordy) + ); + + //=== Input Feed & Reference Generation ============================= + // TH input vectors are batched per round. The replay buffer + // stores TH*SF activation words (TH vectors, each SF folds) + // and the weight buffer replays each tile TH times internally. + // + // Output reorder order: for each TH slot, all NF neuron + // folds in sequence. + accu_t [PE-1:0] Q[$]; + initial begin + wdat = 'x; wvld = 0; + idat = 'x; ivld = 0; + @(posedge clk iff !rst); + + // Wait for DSP startup recovery + repeat(20) @(posedge clk); + + repeat(ROUNDS) begin + // TH activation vectors per batch + automatic activation_t [TH-1:0][MW-1:0] ivecs; + automatic weight_t [MH-1:0][MW-1:0] iwgt; + automatic accu_t [TH-1:0][MH-1:0] ovecs; + + // Randomize all inputs + void'(std::randomize(ivecs, iwgt)); + + // Sanitize weights (narrow + overflow) using first vector + for(int unsigned h = 0; h < MH; h++) begin + automatic accu_t p = 0; + for(int unsigned w = 0; w < MW; w++) begin + automatic weight_t w0 = iwgt[h][w]; + automatic accu_t m0, p0; + + if(CFG.narrow_weights && (w0 == weight_t'(1 << (WEIGHT_WIDTH-1)))) w0++; + m0 = w0 * $signed({CFG.signed_activations && ivecs[0][w][ACTIVATION_WIDTH-1], ivecs[0][w]}); + p0 = p + m0; + if(((m0 < 0) == (p < 0)) && ((m0 < 0) != (p0 < 0))) w0 = 0; + else p = p0; + + iwgt[h][w] = w0; + end + end + + // Compute golden reference for each of TH vectors + for(int unsigned y = 0; y < TH; y++) begin + for(int unsigned h = 0; h < MH; h++) begin + automatic accu_t p = 0; + for(int unsigned w = 0; w < MW; w++) begin + p += $signed(iwgt[h][w]) * $signed({CFG.signed_activations && ivecs[y][w][ACTIVATION_WIDTH-1], ivecs[y][w]}); + end + ovecs[y][h] = p; + end + end + + // Enqueue expected outputs in reorder_out order: + // for each TH slot, all NF neuron folds + for(int unsigned y = 0; y < TH; y++) begin + for(int unsigned h = 0; h < MH; h += PE) begin + Q.push_back(ovecs[y][h+:PE]); + end + end + + // Feed activations and weights concurrently + fork + //-- Activation feed -- + // Replay buffer write order: X (SF) inner, Y (TH) outer. + // Feed TH vectors, each SF SIMD-wide words. + begin : blkActFeed + for(int unsigned y = 0; y < TH; y++) begin + for(int unsigned x = 0; x < SF; x++) begin + while($urandom()%19 == 0) @(posedge clk); + idat <= ivecs[y][x*SIMD +: SIMD]; + ivld <= 1; + @(posedge clk iff irdy); + idat <= 'x; + ivld <= 0; + end + end + end : blkActFeed + + //-- Weight feed -- + // One weight matrix, chunked: for each NF, for each SF, + // send TH chunks of WSIMD weights. + begin : blkWgtFeed + for(int unsigned h = 0; h < MH; h += PE) begin + for(int unsigned w = 0; w < MW; w += SIMD) begin + // Build full PE*SIMD weight tile + automatic weight_t [PE-1:0][SIMD-1:0] wtile; + for(int unsigned pe = 0; pe < PE; pe++) begin + for(int unsigned simd = 0; simd < SIMD; simd++) begin + wtile[pe][simd] = iwgt[h+pe][w+simd]; + end + end + + // Slice into TH chunks of WSIMD weights + for(int unsigned chunk = 0; chunk < TH; chunk++) begin + automatic logic [WEIGHT_STREAM_WIDTH_BA-1:0] wword = '0; + for(int unsigned k = 0; k < WSIMD; k++) begin + automatic int unsigned flat_idx = chunk * WSIMD + k; + automatic int unsigned pe_idx = flat_idx / SIMD; + automatic int unsigned simd_idx = flat_idx % SIMD; + wword[k*WEIGHT_WIDTH +: WEIGHT_WIDTH] = wtile[pe_idx][simd_idx]; + end + + while($urandom()%23 == 0) @(posedge clk); + wdat <= wword; + wvld <= 1; + @(posedge clk iff wrdy); + wdat <= 'x; + wvld <= 0; + end + end + end + end : blkWgtFeed + join + end + + repeat(256) @(posedge clk); + assert(Q.size == 0) else begin + $error("Test #%0d: Missing %0d outputs.", t, Q.size); + $stop; + end + done[t] = 1; + end + + //=== Output Checker ================================================ + int unsigned Checks = 0; + initial begin + ordy = 0; + @(posedge clk iff !rst); + + forever begin + automatic accu_t [PE-1:0] exp; + automatic accu_t [PE-1:0] p; + + while(($urandom() % 59) == 0) @(posedge clk); + + // Drain one output + ordy <= 1; + @(posedge clk iff ovld); + ordy <= 0; + + p = odat; + assert(Q.size > 0) else begin + $error("Test #%0d: Spurious output: %0p.", t, p); + $stop; + end + + exp = Q.pop_front(); + assert(p === exp) else begin + $error("Test #%0d: Output mismatch %0p instead of %0p.", t, p, exp); + $stop; + end + + Checks <= Checks + 1; + end + end + + final begin + assert(Checks == ROUNDS * NF * TH) + $display("Test #%0d: OK -- %0d checks (MH=%0d MW=%0d PE=%0d SIMD=%0d TH=%0d).", + t, Checks, MH, MW, PE, SIMD, TH); + else + $error("Test #%0d: Unexpected check count: %0d instead of %0d.", t, Checks, ROUNDS * NF * TH); + end + + end : genTests + +endmodule : mvu_tiled_axi_tb diff --git a/finn-rtllib/mvu_tiled/weights_buff_tile.sv b/finn-rtllib/mvu_tiled/weights_buff_tile.sv new file mode 100644 index 0000000000..f959140cd3 --- /dev/null +++ b/finn-rtllib/mvu_tiled/weights_buff_tile.sv @@ -0,0 +1,225 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module weights_buff_tile #( + int unsigned WEIGHT_WIDTH = 8, + int unsigned SIMD, + int unsigned PE, + int unsigned TH, + int unsigned WSIMD, + int unsigned NW = (PE*SIMD)/WSIMD, + int unsigned N_DCPL_STAGES +)( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat +); + + //=== Parameter Validation ============================================== + initial begin + if((PE*SIMD) % WSIMD != 0) begin + $error("Weight stream width not set properly (WSIMD: %0d, PE %0d, SIMD %0d).", WSIMD, PE, SIMD); + $finish; + end + end + + //=== Constants and Types =============================================== + localparam int unsigned NW_BITS = (NW == 1)? 1 : $clog2(NW); + localparam int unsigned TH_BITS = (TH == 1)? 1 : $clog2(TH); + + typedef enum logic [1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_e; + typedef enum logic {ST_RD_0, ST_RD_1} state_rd_e; + + //=== Input Slice ======================================================= + uwire ivld_int; + logic irdy_int; + uwire [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat_int; + + skid #(.DATA_WIDTH(WSIMD*WEIGHT_WIDTH), .FEED_STAGES(1)) inst_ireg ( + .clk(clk), .rst(rst), + .ivld(ivld), .irdy(irdy), .idat(idat), + .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) + ); + + //=== Writer ============================================================ + state_wr_e StateWr = ST_WR_0; + state_wr_e state_wr_n; + state_rd_e StateRd = ST_RD_0; + state_rd_e state_rd_n; + + logic [NW_BITS-1:0] Curr = '0; + logic [NW_BITS-1:0] curr_n; + + logic done; + + logic ovld_int; + logic ordy_int; + logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_int; + + logic [1:0][NW-1:0][WSIMD*WEIGHT_WIDTH-1:0] Mem = '0; + logic [1:0][NW-1:0][WSIMD*WEIGHT_WIDTH-1:0] mem_n; + + //--- Writer Sequential ------------------------------------------------- + always_ff @(posedge clk) begin + if(rst) begin + StateWr <= ST_WR_0; + Curr <= '0; + Mem <= '0; + end + else begin + StateWr <= state_wr_n; + Curr <= curr_n; + Mem <= mem_n; + end + end + + //--- Writer Next State ------------------------------------------------- + always_comb begin + state_wr_n = StateWr; + + case(StateWr) + ST_WR_0: + if((Curr == NW - 1) && ivld_int) + state_wr_n = (done || (StateRd == ST_RD_0))? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_0_WAIT: + state_wr_n = (done || (StateRd == ST_RD_0))? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if((Curr == NW - 1) && ivld_int) + state_wr_n = (done || (StateRd == ST_RD_1))? ST_WR_0 : ST_WR_1_WAIT; + + ST_WR_1_WAIT: + state_wr_n = (done || (StateRd == ST_RD_1))? ST_WR_0 : ST_WR_1_WAIT; + endcase + end + + //--- Writer Datapath --------------------------------------------------- + always_comb begin + curr_n = Curr; + mem_n = Mem; + irdy_int = 0; + + case(StateWr) + ST_WR_0, ST_WR_1: begin + irdy_int = 1; + + if(ivld_int) begin + if(StateWr == ST_WR_0) begin + mem_n[0] = (Mem[0] >> WSIMD*WEIGHT_WIDTH); + mem_n[0][NW-1] = idat_int; + end + else begin + mem_n[1] = (Mem[1] >> WSIMD*WEIGHT_WIDTH); + mem_n[1][NW-1] = idat_int; + end + + curr_n = (Curr == NW-1)? 0 : Curr + 1; + end + end + endcase + end + + //=== Reader ============================================================ + logic [TH_BITS-1:0] ConsR = '0; + logic [TH_BITS-1:0] cons_r_n; + + //--- Reader Sequential ------------------------------------------------- + always_ff @(posedge clk) begin + if(rst) begin + StateRd <= ST_RD_0; + ConsR <= 0; + end + else begin + StateRd <= state_rd_n; + ConsR <= cons_r_n; + end + end + + //--- Reader Next State ------------------------------------------------- + always_comb begin + state_rd_n = StateRd; + + case(StateRd) + ST_RD_0: + if(ordy_int && (StateWr != ST_WR_0)) + if(ConsR == TH-1) + state_rd_n = ST_RD_1; + + ST_RD_1: + if(ordy_int && (StateWr != ST_WR_1)) + if(ConsR == TH-1) + state_rd_n = ST_RD_0; + endcase + end + + //--- Reader Datapath --------------------------------------------------- + always_comb begin + cons_r_n = ConsR; + done = 0; + ovld_int = 0; + odat_int = 0; + + case(StateRd) + ST_RD_0: + if(ordy_int && (StateWr != ST_WR_0)) begin + ovld_int = 1; + odat_int = Mem[0]; + done = (ConsR == TH-1); + cons_r_n = (ConsR == TH-1)? 0 : ConsR + 1; + end + + ST_RD_1: + if(ordy_int && (StateWr != ST_WR_1)) begin + ovld_int = 1; + odat_int = Mem[1]; + done = (ConsR == TH-1); + cons_r_n = (ConsR == TH-1)? 0 : ConsR + 1; + end + endcase + end + + //=== Output Slice ====================================================== + skid #(.DATA_WIDTH(PE*SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .ivld(ovld_int), .irdy(ordy_int), .idat(odat_int), + .ovld(ovld), .ordy(ordy), .odat(odat) + ); + +endmodule : weights_buff_tile diff --git a/finn-rtllib/ram/ram_p_c.sv b/finn-rtllib/ram/ram_p_c.sv deleted file mode 100644 index 553121f2f8..0000000000 --- a/finn-rtllib/ram/ram_p_c.sv +++ /dev/null @@ -1,76 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module ram_p_c #( - int unsigned ADDR_BITS, - int unsigned DATA_BITS, - parameter RAM_STYLE = "block" -) ( - input logic clk, - input logic a_en, - input logic [(DATA_BITS/8)-1:0] a_we, - input logic [ADDR_BITS-1:0] a_addr, - input logic b_en, - input logic [ADDR_BITS-1:0] b_addr, - input logic [DATA_BITS-1:0] a_data_in, - output logic [DATA_BITS-1:0] a_data_out, - output logic [DATA_BITS-1:0] b_data_out -); - - localparam int unsigned DEPTH = 2**ADDR_BITS; - - (* ram_style = RAM_STYLE *) logic [DATA_BITS-1:0] ram[DEPTH]; - - reg [DATA_BITS-1:0] a_data_reg = 0; - reg [DATA_BITS-1:0] b_data_reg = 0; - - reg [DATA_BITS-1:0] a_data_q = 0; - reg [DATA_BITS-1:0] b_data_q = 0; - - always_ff @(posedge clk) begin - if(a_en) begin - for (int i = 0; i < (DATA_BITS/8); i++) begin - if(a_we[i]) begin - ram[a_addr][(i*8)+:8] <= a_data_in[(i*8)+:8]; - end - end - a_data_reg <= ram[a_addr]; - a_data_out <= a_data_reg; - end - if(b_en) begin - b_data_reg <= ram[b_addr]; - b_data_out <= b_data_reg; - end - //end - end - -endmodule : ram_p_c diff --git a/finn_xsi/finn_xsi/sim_engine.py b/finn_xsi/finn_xsi/sim_engine.py index 0d17e581af..7288650624 100644 --- a/finn_xsi/finn_xsi/sim_engine.py +++ b/finn_xsi/finn_xsi/sim_engine.py @@ -109,7 +109,7 @@ def run(self, cycles=None): # Execute Cycle self.ticks += 1 - print(f"Cycle {self.ticks}") + # print(f"Cycle {self.ticks}") strong = False for task in self.tasks: # Tasks read signals and derive updates to schedule for after the clock cycle diff --git a/src/finn/core/rtlsim_exec.py b/src/finn/core/rtlsim_exec.py index b734a181e5..db19e3f730 100644 --- a/src/finn/core/rtlsim_exec.py +++ b/src/finn/core/rtlsim_exec.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json import numpy as np import os from qonnx.custom_op.registry import getCustomOp @@ -358,6 +359,22 @@ def rtlsim_exec_finnxsi(model, execution_context, pre_hook=None, post_hook=None) # reset and call rtlsim, including any pre/post hooks finnxsi.reset_rtlsim(sim) + + # automatically load AXI-MM weight images for external_mem nodes + aximm_weights_json = model.get_metadata_prop("vivado_stitch_aximm_weights") + if aximm_weights_json is not None: + aximm_weights = json.loads(aximm_weights_json) + for aximm_name, npy_path in aximm_weights.items(): + weight_npy = np.load(npy_path) + # Pack weight values (int8 etc.) into a flat byte array for AXI-MM + # Each element is one weight; pack them LSB-first per line + weight_bytes = [] + for line in weight_npy.reshape(-1, weight_npy.shape[-1]): + for val in line: + weight_bytes.append(int(val) & 0xFF) + weight_data = np.array(weight_bytes, dtype=np.uint8) + sim.aximm_ro_image(aximm_name, 0, weight_data.flatten()) + if pre_hook is not None: pre_hook(sim) n_cycles = finnxsi.rtlsim_multi_io( diff --git a/src/finn/custom_op/fpgadataflow/hls/matrixvectoractivation_hls.py b/src/finn/custom_op/fpgadataflow/hls/matrixvectoractivation_hls.py index 6abafc8eec..902808f5c8 100644 --- a/src/finn/custom_op/fpgadataflow/hls/matrixvectoractivation_hls.py +++ b/src/finn/custom_op/fpgadataflow/hls/matrixvectoractivation_hls.py @@ -139,11 +139,10 @@ def dsp_estimation(self, fpgapart): def code_generation_ipgen(self, model, fpgapart, clk): """Generates c++ code and tcl script for ip generation.""" super().code_generation_ipgen(model, fpgapart, clk) - dynamic_input = self.get_nodeattr("dynamic_input") mem_mode = self.get_nodeattr("mem_mode") - if dynamic_input: + if mem_mode == "dynamic": self.generate_hdl_dynload() - if mem_mode == "internal_decoupled" and not self.get_nodeattr("mlo_max_iter"): + if mem_mode == "internal_decoupled": if self.get_nodeattr("ram_style") == "ultra" and not is_versal(fpgapart): runtime_writeable = self.get_nodeattr("runtime_writeable_weights") assert ( @@ -151,7 +150,7 @@ def code_generation_ipgen(self, model, fpgapart, clk): ), """Layer with URAM weights must have runtime_writeable_weights=1 if Ultrascale device is targeted.""" self.generate_hdl_memstream(fpgapart, pumped_memory=self.get_nodeattr("pumpedMemory")) - elif self.get_nodeattr("mlo_max_iter"): + elif mem_mode == "external_mem": self.generate_hdl_fetch_weights(fpgapart) def get_template_param_values(self): @@ -236,11 +235,7 @@ def defines(self, var): numReps, ) ] - if ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: wdt = self.get_input_datatype(1) self.code_gen_dict["$DEFINES$"].append("#define WP1 {}\n".format(wdt.bitwidth())) @@ -268,14 +263,10 @@ def read_npy_data(self): ) mem_mode = self.get_nodeattr("mem_mode") - if ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: wdt = self.get_input_datatype(1) packed_bits = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if mem_mode == "dynamic": packed_bits = packed_bits * self.get_nodeattr("SIMD") packed_hls_type = "ap_uint<%d>" % packed_bits elem_hls_type = wdt.get_hls_datatype_str() @@ -302,13 +293,9 @@ def strm_decl(self): 'hls::stream> out0_V ("out0_V");'.format(self.get_outstream_width()) ) - if ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: iwidth = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if mem_mode == "dynamic": iwidth = iwidth * self.get_nodeattr("SIMD") self.code_gen_dict["$STREAMDECLARATIONS$"].append( 'hls::stream> in1_V ("in1_V");'.format(iwidth) @@ -338,11 +325,7 @@ def docompute(self): map_to_hls_mult_style[self.get_nodeattr("resType")], ) ] - elif ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + elif mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: wdt = self.get_input_datatype(1) if wdt == DataType["BIPOLAR"]: export_wdt = DataType["BINARY"] @@ -408,13 +391,9 @@ def blackboxfunction(self): self.get_outstream_width(), ) ] - elif ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + elif mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: wwidth = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if mem_mode == "dynamic": wwidth = wwidth * self.get_nodeattr("SIMD") self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ """void {}( @@ -449,11 +428,7 @@ def pragmas(self): self.code_gen_dict["$PRAGMAS$"].append( ("#pragma HLS ARRAY_PARTITION variable=weights.m_weights " "complete dim=1") ) - elif ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + elif mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=in1_V") else: @@ -494,7 +469,8 @@ def get_ap_int_max_w(self): # internal_decoupled mode weight stream weightstream = self.get_instream_width(1) simd = self.get_nodeattr("SIMD") - if self.get_nodeattr("dynamic_input"): + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode == "dynamic": weightstream = weightstream * simd # single PE weight entry weight_bits = self.get_input_datatype(1).bitwidth() @@ -503,7 +479,6 @@ def get_ap_int_max_w(self): def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") - dynamic_input = self.get_nodeattr("dynamic_input") mem_mode = self.get_nodeattr("mem_mode") node = self.onnx_node @@ -547,7 +522,7 @@ def execute_node(self, context, graph): ) if in_ind == 1: - if dynamic_input: + if mem_mode in ["dynamic", "external", "internal_decoupled", "external_mem"]: reshaped_input = context[inputs].reshape(-1, context[inputs].shape[-1]) self.make_weight_file( reshaped_input, "decoupled_npy", "{}/input_1.npy".format(code_gen_dir) @@ -572,13 +547,9 @@ def execute_node(self, context, graph): inp = npy_to_rtlsim_input("{}/input_0.npy".format(code_gen_dir), export_idt, nbits) self.reset_rtlsim(sim) - if ( - dynamic_input - or mem_mode in ["external", "internal_decoupled"] - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["external", "internal_decoupled", "external_mem", "dynamic"]: wnbits = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if mem_mode == "dynamic": wnbits = wnbits * self.get_nodeattr("SIMD") export_wdt = self.get_input_datatype(1) @@ -682,9 +653,8 @@ def instantiate_ip(self, cmd): # instantiate the HLS IP vlnv = self.get_nodeattr("ip_vlnv") node_name = self.onnx_node.name - if self.get_nodeattr("mem_mode") == "internal_decoupled" or self.get_nodeattr( - "mlo_max_iter" - ): + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode in ["internal_decoupled", "external_mem", "dynamic"]: cmd.append("create_bd_cell -type ip -vlnv %s /%s/%s" % (vlnv, node_name, node_name)) else: cmd.append("create_bd_cell -type ip -vlnv %s %s" % (vlnv, node_name)) diff --git a/src/finn/custom_op/fpgadataflow/hwcustomop.py b/src/finn/custom_op/fpgadataflow/hwcustomop.py index a1c2bbe2f6..1db619cea2 100644 --- a/src/finn/custom_op/fpgadataflow/hwcustomop.py +++ b/src/finn/custom_op/fpgadataflow/hwcustomop.py @@ -348,12 +348,14 @@ def generate_hdl_memstream(self, fpgapart, pumped_memory=0): else: pass - def generate_hdl_fetch_weights(self, fpgapart): + def generate_hdl_fetch_weights(self): """Helper function to generate verilog code for fetch_weights component. Currently utilized by MVAU.""" ops = ["MVAU_hls", "MVAU_rtl"] if self.onnx_node.op_type in ops or self.onnx_node.op_type.startswith("Elementwise"): - template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mlo/fetch_weights_wrapper.v" + template_path = ( + os.environ["FINN_ROOT"] + "/finn-rtllib/fetch_weights/fetch_weights_wrapper.v" + ) mname = self.onnx_node.name wdt = self.get_input_datatype(1) if self.onnx_node.op_type in ops: @@ -361,7 +363,9 @@ def generate_hdl_fetch_weights(self, fpgapart): mh = self.get_nodeattr("MH") pe = self.get_nodeattr("PE") simd = self.get_nodeattr("SIMD") - n_reps = np.prod(self.get_nodeattr("numInputVectors")) + theight = self.get_nodeattr("TH") + n_reps = np.prod(self.get_nodeattr("numInputVectors")) // theight + en_mlo = "EN_MLO" if self.get_nodeattr("mlo_max_iter") else "NO_MLO" else: # Eltwise layers only have one parallelism parameter mw = 1 @@ -370,10 +374,24 @@ def generate_hdl_fetch_weights(self, fpgapart): simd = 1 # TODO use broadcast rhs shape here n_reps = np.prod(self.get_nodeattr("rhs_shape")[:-1]) - layer_offs = mw * mh + theight = 1 + en_mlo = "EN_MLO" if self.get_nodeattr("mlo_max_iter") else "NO_MLO" # upper bound on how many layers can be supported, set to 64 for now n_max_layers = 64 code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + + # Compute IWSIMD and WSIMD for the fetch_weights wrapper + if self.onnx_node.op_type in ops: + if theight > 1: + iwsimd = (pe * simd) // theight + wsimd = (pe * simd) // theight + else: + iwsimd = simd + wsimd = (pe * simd) // theight + else: + iwsimd = simd + wsimd = (pe * simd) // theight + code_gen_dict = { "$MODULE_NAME_AXI_WRAPPER$": [mname + "_fetch_weights_wrapper"], "$MW$": [str(mw)], @@ -382,8 +400,12 @@ def generate_hdl_fetch_weights(self, fpgapart): "$SIMD$": [str(simd)], "$N_REPS$": [str(n_reps)], "$WEIGHT_WIDTH$": [str(wdt.bitwidth())], - "$LAYER_OFFS$": [str(layer_offs)], "$N_LAYERS$": [str(n_max_layers)], + "$TH$": [str(theight)], + "$IWSIMD$": [str(iwsimd)], + "$WSIMD$": [str(wsimd)], + "$EN_MLO$": [en_mlo], + "$DWC_MODULE_NAME$": [mname + "_dwc"], } # apply code generation to template with open(template_path, "r") as f: diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index 05b5454ab1..0d25732016 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -40,6 +40,7 @@ ) from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp +from finn.util.basic import is_versal from finn.util.data_packing import numpy_to_hls_code, pack_innermost_dim_as_hex_string # ONNX i/o tensor shape assumptions for MatrixVectorActivation: @@ -88,7 +89,7 @@ def get_nodeattr_types(self): "s", False, "internal_decoupled", - {"internal_embedded", "internal_decoupled", "external"}, + {"internal_embedded", "internal_decoupled", "external", "external_mem", "dynamic"}, ), # FPGA resource type for memories in internal_decoupled mode # auto -- let Vivado decide @@ -123,8 +124,15 @@ def get_nodeattr_types(self): # weight data from the weight FIFOs. "runtime_writeable_weights": ("i", False, 0, {0, 1}), "pumpedMemory": ("i", False, 0, {0, 1}), - # dynamic input - "dynamic_input": ("i", False, 0, {0, 1}), + # tiling + "TH": ("i", False, 1), + # MMU parameters + "gemm_type": ( + "s", + False, + "mvau", + {"mvau", "mvau_tiled"}, + ), } my_attrs.update(super().get_nodeattr_types()) return my_attrs @@ -190,6 +198,8 @@ def verify_node(self): except Exception: info_messages.append("""The required MatrixVectorActivation attributes do not exist.""") + # TODO: Verify matrix unit type + # verify the number of inputs depending on noActivation value # check noActivation value to determine the number of inputs no_act = self.get_nodeattr("noActivation") @@ -253,26 +263,24 @@ def get_output_datatype(self, ind=0): """Returns FINN DataType of output.""" return DataType[self.get_nodeattr("outputDataType")] - def get_instream_width(self, ind=0): + def get_instream_width(self, ind=0): # TODO: Hacky, need to clean these calls ... if ind == 0: i_bits = self.get_input_datatype(0).bitwidth() width = i_bits * self.get_nodeattr("SIMD") elif ind == 1: - if self.get_nodeattr("dynamic_input"): - width = ( - self.get_folded_input_shape(ind)[-1] * self.get_input_datatype(ind).bitwidth() - ) - elif ( - self.get_nodeattr("mem_mode") == "internal_decoupled" - or self.get_nodeattr("mem_mode") == "external" - or self.get_nodeattr("mlo_max_iter") - ): - pe = self.get_nodeattr("PE") - simd = self.get_nodeattr("SIMD") - wp = self.get_input_datatype(1).bitwidth() - width = pe * simd * wp - else: - width = 0 + pe = self.get_nodeattr("PE") + simd = self.get_nodeattr("SIMD") + wp = self.get_input_datatype(1).bitwidth() + mem_mode = self.get_nodeattr("mem_mode") + theight = self.get_nodeattr("TH") + + match mem_mode: + case "dynamic": + width = pe * wp + case "external" | "external_mem" | "internal_decoupled": + width = ((pe * simd) * wp) // theight + case _: + width = 0 elif ind == 2: # check if integrated thresholding and return 0 # because threshold values are always embedded @@ -297,22 +305,26 @@ def get_folded_input_shape(self, ind=0): mh = self.get_nodeattr("MH") simd = self.get_nodeattr("SIMD") pe = self.get_nodeattr("PE") + mem_mode = self.get_nodeattr("mem_mode") sf = mw // simd nf = mh // pe vecs = list(self.get_nodeattr("numInputVectors")) + n_vecs = int(np.prod(vecs)) + theight = self.get_nodeattr("TH") if ind == 0: # calculate shape of input 0 folded_input_shape = tuple(vecs + [sf, simd]) elif ind == 1: - if self.get_nodeattr("dynamic_input"): - # calculate shape of input 1 (weights dynamic) - folded_input_shape = tuple(vecs[:2] + [mw] + [nf, pe]) - elif self.get_nodeattr("mem_mode") == "external" or self.get_nodeattr("mlo_max_iter"): - # calculate shape of input 1 (weights) - folded_input_shape = tuple(vecs + [sf * nf, simd * pe]) - else: - raise Exception("Undefined input shape for requested input") + match mem_mode: + case "dynamic": + folded_input_shape = tuple(vecs[:2] + [mw] + [nf, pe]) + case "external" | "external_mem" | "internal_decoupled": + folded_input_shape = (n_vecs, sf * nf, (simd * pe) // theight) + case _: + raise Exception("Undefined input shape for requested input") + else: + raise Exception("Undefined input shape for requested input") return folded_input_shape @@ -377,7 +389,7 @@ def uram_estimation(self): (mmode == "internal_decoupled" and mstyle != "ultra") or (mmode == "internal_embedded" and self.calc_wmem() <= 128) or (mmode == "external") - or self.get_nodeattr("mlo_max_iter") + or (mmode == "external_mem") ): return 0 width_multiplier = math.ceil(mem_width / 72) @@ -407,7 +419,7 @@ def bram_estimation(self): (mmode == "internal_decoupled" and mstyle in ["distributed", "ultra"]) or (mmode == "internal_embedded" and self.calc_wmem() <= 128) or (mmode == "external") - or self.get_nodeattr("mlo_max_iter") + or (mmode == "external_mem") ): return 0 # assuming SDP mode RAMB18s (see UG573 Table 1-10) @@ -458,12 +470,15 @@ def uram_efficiency_estimation(self): def get_exp_cycles(self): pe = self.get_nodeattr("PE") simd = self.get_nodeattr("SIMD") + th = self.get_nodeattr("TH") num_inp_vec = self.get_nodeattr("numInputVectors") mh = self.get_nodeattr("MH") mw = self.get_nodeattr("MW") # since mmv != 1 is not supported yet, we set mmv for now to 1 mmv = 1 - exp_cycles = (mh / pe) * (mw / simd) * np.prod(num_inp_vec) / mmv + # Tiling/systolic reduces throughput + # TH>1 (tiling) reduces throughput by factor TH (tinner = PE*SIMD/TH) + exp_cycles = (mh / pe) * (mw / simd) * np.prod(num_inp_vec) * th / mmv return int(exp_cycles) def minimize_accumulator_width(self, model): @@ -477,13 +492,12 @@ def minimize_accumulator_width(self, model): idt = self.get_input_datatype(0) - # if runtime-writeable weights or mem_mode=external, then the values of the weights can - # change and we need to use the worst-case values from the datatypes + # if runtime-writeable weights, mem_mode=external, or weights are absent (MLO), + # then we need to use the worst-case values from the datatypes if ( self.get_nodeattr("runtime_writeable_weights") - or self.get_nodeattr("mem_mode") == "external" - or self.get_nodeattr("mlo_max_iter") - or self.get_nodeattr("dynamic_input") + or self.get_nodeattr("mem_mode") in ["external", "external_mem", "dynamic"] + or weights is None ): mw = self.get_nodeattr("MW") mh = self.get_nodeattr("MH") @@ -530,11 +544,11 @@ def minimize_weight_bit_width(self, model): """Minimize the bit width based on the values of the weights.""" if not ( self.get_nodeattr("runtime_writeable_weights") - or self.get_nodeattr("mem_mode") == "external" - or self.get_nodeattr("mlo_max_iter") - or self.get_nodeattr("dynamic_input") + or self.get_nodeattr("mem_mode") in ["external", "external_mem", "dynamic"] ): weights = model.get_initializer(self.onnx_node.input[1]) + if weights is None: + return DataType[self.get_nodeattr("weightDataType")] w_min = weights.min() w_max = weights.max() if w_min < 0: @@ -696,13 +710,17 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name): # flipped weight_tensor_pe_flipped = weight_tensor_pe_flipped.reshape(1, -1, pe * simd) weight_tensor_pe_flipped = weight_tensor_pe_flipped.copy() + # tiling + tinner = (pe * simd) // self.get_nodeattr("TH") + weight_tensor_simd_flipped = weight_tensor_simd_flipped.reshape(1, -1, tinner) + weight_tensor_pe_flipped = weight_tensor_pe_flipped.reshape(1, -1, tinner) if weight_file_mode == "decoupled_npy": # save weight stream into npy for cppsim np.save(weight_file_name, weight_tensor_simd_flipped) elif weight_file_mode == "decoupled_verilog_dat": # convert weight values into hexstring weight_width = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if self.get_nodeattr("mem_mode") == "dynamic": weight_width = weight_width * simd # pad to nearest 4 bits to get hex strings weight_width_padded = roundup_to_integer_multiple(weight_width, 4) @@ -734,7 +752,7 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name): # memstream axi-lite interface will map each mem line to # one or multiple 32-bit words weight_width = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if self.get_nodeattr("mem_mode") == "dynamic": weight_width = weight_width * simd words_per_memwidth = 2 ** math.ceil(math.log2(weight_width / 32)) if words_per_memwidth < 1: @@ -765,25 +783,22 @@ def generate_params(self, model, path): # weights, if not external weights = model.get_initializer(self.onnx_node.input[1]) if weights is not None: - if mem_mode == "internal_embedded": - # save hlslib-compatible weights in params.h - weight_filename = "{}/params.h".format(code_gen_dir) - self.make_weight_file(weights, "hls_header", weight_filename) - elif mem_mode == "internal_decoupled" or mem_mode == "external": - weight_filename_sim = "{}/input_1.npy".format(code_gen_dir) - # save internal_decoupled weights for cppsim - self.make_weight_file(weights, "decoupled_npy", weight_filename_sim) - if mem_mode == "internal_decoupled": + match mem_mode: + case "internal_embedded": + # save hlslib-compatible weights in params.h + weight_filename = "{}/params.h".format(code_gen_dir) + self.make_weight_file(weights, "hls_header", weight_filename) + case "internal_decoupled" | "external" | "external_mem": + weight_filename_sim = "{}/input_1.npy".format(code_gen_dir) + # save internal_decoupled weights for cppsim + self.make_weight_file(weights, "decoupled_npy", weight_filename_sim) + # if mem_mode == "internal_decoupled": # also save weights as Verilog .dat file # This file will be ignored when synthesizing UltraScale memory. weight_filename_rtl = "{}/memblock.dat".format(code_gen_dir) self.make_weight_file(weights, "decoupled_verilog_dat", weight_filename_rtl) else: - if not ( - mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - or self.get_nodeattr("dynamic_input") - ): + if mem_mode not in ["external", "dynamic", "external_mem"]: raise Exception( """Invalid setting found, weight values not initialized, but neither "external" case nor MLO.""" @@ -897,33 +912,48 @@ def get_verilog_top_module_intf_names(self): if pumped_compute or self.get_nodeattr("pumpedMemory"): intf_names["clk2x"] = ["ap_clk2x"] - if self.get_nodeattr("mlo_max_iter"): - intf_names["aximm"].append(("axi_mm", 64)) - intf_names["s_axis"].append(("in_idx0_V", 32)) - else: - dynamic_input = self.get_nodeattr("dynamic_input") - mem_mode = self.get_nodeattr("mem_mode") - if dynamic_input: - weight_width = self.get_instream_width(1) - weight_width = weight_width * self.get_nodeattr("SIMD") - intf_names["s_axis"].append(("in1_V", roundup_to_integer_multiple(weight_width, 8))) - else: - if mem_mode == "external": - intf_names["s_axis"].append(("in1_V", self.get_instream_width_padded(1))) - elif mem_mode == "internal_decoupled": - # only expose axilite interface if attribute is set - runtime_writeable = self.get_nodeattr("runtime_writeable_weights") - if runtime_writeable: - intf_names["axilite"] = ["s_axilite"] + match self.get_nodeattr("mem_mode"): + case "external_mem": + intf_names["aximm"].append(("axi_mm", 64)) + if self.get_nodeattr("mlo_max_iter") > 0: + intf_names["s_axis"].append(("in_idx0_V", 32)) + case "dynamic" | "external": + intf_names["s_axis"].append(("in1_V", self.get_instream_width_padded(1))) + case "internal_decoupled": + # only expose axilite interface if attribute is set + if self.get_nodeattr("runtime_writeable_weights"): + intf_names["axilite"] = ["s_axilite"] + return intf_names + def generate_hdl(self, fpgapart): + mem_mode = self.get_nodeattr("mem_mode") + + match mem_mode: + case "dynamic": + self.generate_hdl_dynload() + case "external_mem": + self.generate_hdl_fetch_weights() + case "internal_decoupled": + if self.get_nodeattr("ram_style") == "ultra" and not is_versal(fpgapart): + assert ( + self.get_nodeattr("runtime_writeable_weights") == 1 + ), """Layer with URAM weights must have runtime_writeable_weights=1 + if Ultrascale device is targeted.""" + self.generate_hdl_memstream( + fpgapart, pumped_memory=self.get_nodeattr("pumpedMemory") + ) + def code_generation_ipi(self): source_target = "./ip/verilog/rtl_ops/%s" % self.onnx_node.name cmd = ["file mkdir %s" % source_target] - dyn_input = self.get_nodeattr("dynamic_input") - mem_mode = self.get_nodeattr("mem_mode") + + # # check if additional components are needed - if mem_mode == "internal_decoupled" or self.get_nodeattr("mlo_max_iter") or dyn_input: + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode in ["internal_decoupled", "dynamic", "external_mem"]: + # + # Base runtime_writeable = self.get_nodeattr("runtime_writeable_weights") node_name = self.onnx_node.name # create a hierarchy for this layer, with the same port names @@ -955,133 +985,174 @@ def code_generation_ipi(self): "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" % (node_name, din_name) ) - if self.get_nodeattr("mlo_max_iter"): - cmd.append( - "create_bd_intf_pin -mode Slave " - "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" % (node_name, "in_idx0_V") - ) - cmd.append( - "create_bd_intf_pin -mode Master " - "-vlnv xilinx.com:interface:aximm_rtl:1.0 /%s/%s" % (node_name, "axi_mm") - ) - + # # Instantiate either the HLS or RTL IP depending on operator self.instantiate_ip(cmd) code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") - if dyn_input: - # additional dynamic input - win_name = self.get_verilog_top_module_intf_names()["s_axis"][1][0] - cmd.append( - "create_bd_intf_pin -mode Slave " - "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" % (node_name, win_name) - ) - # dynamic loader - ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") - dyn_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/dynload/hdl/") - file_suffix = "_dynamic_load_wrapper.v" - # automatically find memstream verilog component in code generation directory - for fname in os.listdir(code_gen_dir): - if fname.endswith(file_suffix): - strm_tmpl = fname - strm_tmpl_name = strm_tmpl[:-2] - sourcefiles = [ - os.path.join(code_gen_dir, strm_tmpl), - ram_rtllib_dir + "ram_p_c.sv", - dyn_rtllib_dir + "dynamic_load.sv", - ] - for f in sourcefiles: - cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] - strm_inst = node_name + "_wdynld" - strm_out_name = "m_axis_0" - elif self.get_nodeattr("mlo_max_iter"): - # instantiate a fetch weights component and connect it to the IP - mlo_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mlo/") - reg_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/skid/") - ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") - dwc_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/dwc/hdl/") - dma_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/cdma/") - file_suffix = "_fetch_weights_wrapper.v" - # automatically find memstream verilog component in code generation directory - for fname in os.listdir(code_gen_dir): - if fname.endswith(file_suffix): - strm_tmpl = fname - strm_tmpl_name = strm_tmpl[:-2] - sourcefiles = [ - os.path.join(code_gen_dir, strm_tmpl), - reg_rtllib_dir + "skid.sv", - ram_rtllib_dir + "ram_p_c.sv", - dwc_rtllib_dir + "axis_adapter.v", - dwc_rtllib_dir + "axis_fifo_adapter.sv", - dwc_rtllib_dir + "axis_fifo.v", - mlo_rtllib_dir + "fetch_weights.sv", - mlo_rtllib_dir + "local_weight_buffer.sv", - ] - # add files from cdma dir - for file in os.listdir(dma_rtllib_dir): - if file.endswith(".sv") or file.endswith(".svh"): - sourcefiles.append(os.path.join(dma_rtllib_dir, file)) - for file in os.listdir(dma_rtllib_dir + "cdma_a/"): - if file.endswith(".sv") or file.endswith(".svh"): - sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_a/", file)) - for file in os.listdir(dma_rtllib_dir + "cdma_u/"): - if file.endswith(".sv") or file.endswith(".svh"): - sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_u/", file)) - for file in os.listdir(dma_rtllib_dir + "cdma_x/"): - if file.endswith(".sv") or file.endswith(".svh"): - sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_x/", file)) - - for f in sourcefiles: - cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] - strm_inst = node_name + "_fetch_weights" - strm_out_name = "out0_V" - # update intf dict to remove weights input and replace with index/tap input - self.get_verilog_top_module_intf_names()["s_axis"] - - elif mem_mode == "internal_decoupled": - # instantiate a streamer and connect it to the IP - axi_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/axi/hdl/") - ms_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/memstream/hdl/") - file_suffix = "_memstream_wrapper.v" - # automatically find memstream verilog component in code generation directory - for fname in os.listdir(code_gen_dir): - if fname.endswith(file_suffix): - strm_tmpl = fname - strm_tmpl_name = strm_tmpl[:-2] - sourcefiles = [ - os.path.join(code_gen_dir, strm_tmpl), - axi_dir + "axilite.sv", - ms_rtllib_dir + "memstream_axi.sv", - ms_rtllib_dir + "memstream.sv", - ] - for f in sourcefiles: - cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] - strm_inst = node_name + "_wstrm" - strm_out_name = "m_axis_0" + + match mem_mode: + # + # Dynamic loader instantiation + case "dynamic": + # additional dynamic input + win_name = self.get_verilog_top_module_intf_names()["s_axis"][1][0] + cmd.append( + "create_bd_intf_pin -mode Slave " + "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" % (node_name, win_name) + ) + + # dynamic loader + ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") + dyn_rtllib_dir = os.path.join( + os.environ["FINN_ROOT"], "finn-rtllib/dynload/hdl/" + ) + file_suffix = "_dynamic_load_wrapper.v" + # automatically find memstream verilog component in code generation directory + for fname in os.listdir(code_gen_dir): + if fname.endswith(file_suffix): + strm_tmpl = fname + strm_tmpl_name = strm_tmpl[:-2] + sourcefiles = [ + os.path.join(code_gen_dir, strm_tmpl), + dyn_rtllib_dir + "dynamic_load.sv", + ] + for f in sourcefiles: + cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] + strm_inst = node_name + "_wdynld" + strm_out_name = "m_axis_0" + + # + # Fetch weights instantiation (MLO or TODO: tiling) + case "external_mem": + # additional inputs + cmd.append( + "create_bd_intf_pin -mode Master " + "-vlnv xilinx.com:interface:aximm_rtl:1.0 /%s/%s" % (node_name, "axi_mm") + ) + if self.get_nodeattr("mlo_max_iter") > 0: + cmd.append( + "create_bd_intf_pin -mode Slave " + "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" + % (node_name, "in_idx0_V") + ) + + # instantiate a fetch weights component and connect it to the IP + ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") + reg_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/skid/") + que_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/fifo/hdl/") + fwg_rtllib_dir = os.path.join( + os.environ["FINN_ROOT"], "finn-rtllib/fetch_weights/" + ) + dma_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/cdma/") + file_suffix = "_fetch_weights_wrapper.v" + # automatically find memstream verilog component in code generation directory + for fname in os.listdir(code_gen_dir): + if fname.endswith(file_suffix): + strm_tmpl = fname + strm_tmpl_name = strm_tmpl[:-2] + sourcefiles = [ + os.path.join(code_gen_dir, strm_tmpl), + reg_rtllib_dir + "skid.sv", + que_rtllib_dir + "Q_srl.v", + fwg_rtllib_dir + "fetch_weights.sv", + fwg_rtllib_dir + "local_weight_buffer.sv", + ] + # Create Vivado axis_dwidth_converter IP + theight = self.get_nodeattr("TH") + wdt = self.get_input_datatype(1) + if theight > 1: + iwsimd = (self.get_nodeattr("PE") * self.get_nodeattr("SIMD")) // theight + else: + iwsimd = self.get_nodeattr("SIMD") + ds_bits_ba = ((iwsimd * wdt.bitwidth() + 7) // 8) * 8 + dwc_ip_name = node_name + "_dwc" + s_bytes = 256 // 8 + m_bytes = ds_bits_ba // 8 + cmd += [ + "create_ip -name axis_dwidth_converter -vendor xilinx.com " + "-library ip -version 1.1 -module_name %s" % dwc_ip_name, + "set_property -dict [list " + "CONFIG.S_TDATA_NUM_BYTES {%d} " + "CONFIG.M_TDATA_NUM_BYTES {%d} " + "CONFIG.HAS_TLAST {1} " + "CONFIG.HAS_TKEEP {1} " + "] [get_ips %s]" % (s_bytes, m_bytes, dwc_ip_name), + "generate_target all [get_ips %s]" % dwc_ip_name, + ] + + # add files from cdma dir + for file in os.listdir(dma_rtllib_dir): + if file.endswith(".sv") or file.endswith(".svh"): + sourcefiles.append(os.path.join(dma_rtllib_dir, file)) + for file in os.listdir(dma_rtllib_dir + "cdma_a/"): + if file.endswith(".sv") or file.endswith(".svh"): + sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_a", file)) + for file in os.listdir(dma_rtllib_dir + "cdma_u/"): + if file.endswith(".sv") or file.endswith(".svh"): + sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_u/", file)) + for file in os.listdir(dma_rtllib_dir + "cdma_x/"): + if file.endswith(".sv") or file.endswith(".svh"): + sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_x/", file)) + for f in sourcefiles: + cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] + strm_inst = node_name + "_fetch_weights" + strm_out_name = "out0_V" + # update intf dict to remove weights input and replace with index/tap input + self.get_verilog_top_module_intf_names()["s_axis"] + + # + # Memstream instantiation + case "internal_decoupled": + # instantiate a streamer and connect it to the IP + axi_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/axi/hdl/") + ms_rtllib_dir = os.path.join( + os.environ["FINN_ROOT"], "finn-rtllib/memstream/hdl/" + ) + file_suffix = "_memstream_wrapper.v" + # automatically find memstream verilog component in code generation directory + for fname in os.listdir(code_gen_dir): + if fname.endswith(file_suffix): + strm_tmpl = fname + strm_tmpl_name = strm_tmpl[:-2] + sourcefiles = [ + os.path.join(code_gen_dir, strm_tmpl), + axi_dir + "axilite.sv", + ms_rtllib_dir + "memstream_axi.sv", + ms_rtllib_dir + "memstream.sv", + ] + for f in sourcefiles: + cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] + strm_inst = node_name + "_wstrm" + strm_out_name = "m_axis_0" cmd.append( "create_bd_cell -type hier -reference %s /%s/%s" % (strm_tmpl_name, node_name, strm_inst) ) - if self.get_nodeattr("mlo_max_iter"): - cmd.append( - "connect_bd_intf_net [get_bd_intf_pins %s/%s] " - "[get_bd_intf_pins %s/%s/%s]" - % (node_name, "in_idx0_V", node_name, strm_inst, "in_idx0_V") - ) + # + # Connect + match mem_mode: + case "dynamic": + cmd.append( + "connect_bd_intf_net [get_bd_intf_pins %s/%s] " + "[get_bd_intf_pins %s/%s/s_axis_0]" + % (node_name, win_name, node_name, strm_inst) + ) - cmd.append( - "connect_bd_intf_net [get_bd_intf_pins %s/%s] " - "[get_bd_intf_pins %s/%s/%s]" - % (node_name, "axi_mm", node_name, strm_inst, "axi_mm") - ) + case "external_mem": + cmd.append( + "connect_bd_intf_net [get_bd_intf_pins %s/%s] " + "[get_bd_intf_pins %s/%s/%s]" + % (node_name, "axi_mm", node_name, strm_inst, "axi_mm") + ) + if self.get_nodeattr("mlo_max_iter") > 0: + cmd.append( + "connect_bd_intf_net [get_bd_intf_pins %s/%s] " + "[get_bd_intf_pins %s/%s/%s]" + % (node_name, "in_idx0_V", node_name, strm_inst, "in_idx0_V") + ) - if dyn_input: - cmd.append( - "connect_bd_intf_net [get_bd_intf_pins %s/%s] " - "[get_bd_intf_pins %s/%s/s_axis_0]" - % (node_name, win_name, node_name, strm_inst) - ) cmd.append( "connect_bd_intf_net [get_bd_intf_pins %s/%s/%s] " "[get_bd_intf_pins %s/%s/in1_V]" @@ -1095,11 +1166,10 @@ def code_generation_ipi(self): "connect_bd_net [get_bd_pins %s/%s] [get_bd_pins %s/%s/ap_clk]" % (node_name, clk_name, node_name, strm_inst) ) + # if using 2x pumped memory, connect the memstreamer's 2x clk input # to the 2x clock port. otherwise connect it to the regular clock port. - if mem_mode == "internal_decoupled" and not ( - self.get_nodeattr("mlo_max_iter") or dyn_input - ): + if mem_mode == "internal_decoupled": if self.get_nodeattr("pumpedMemory"): cmd.append( "connect_bd_net [get_bd_pins %s/%s] [get_bd_pins %s/%s/ap_clk2x]" @@ -1110,8 +1180,8 @@ def code_generation_ipi(self): "connect_bd_net [get_bd_pins %s/%s] [get_bd_pins %s/%s/ap_clk2x]" % (node_name, clk_name, node_name, strm_inst) ) - # runtime writeable weights - if runtime_writeable: + # runtime writeable weights (skip for MLO nodes) + if runtime_writeable and not self.get_nodeattr("mlo_max_iter"): axilite_name = self.get_verilog_top_module_intf_names()["axilite"][0] cmd.append( "create_bd_intf_pin -mode Slave " @@ -1125,6 +1195,7 @@ def code_generation_ipi(self): ) # TODO calculate and pass in segment size here cmd.append("assign_bd_address") + cmd.append( "connect_bd_net [get_bd_pins %s/%s] [get_bd_pins %s/%s/%s]" % (node_name, rst_name, node_name, node_name, rst_name) @@ -1146,11 +1217,12 @@ def code_generation_ipi(self): # save bd cmd.append("save_bd_design") - elif (mem_mode == "internal_embedded" or mem_mode == "external") and not self.get_nodeattr( - "mlo_max_iter" - ): + + elif mem_mode in ["internal_embedded", "external"]: # base class impl sufficient for internal_embedded/external modes self.instantiate_ip(cmd) + else: raise Exception("Unrecognized mem_mode for MatrixVectorActivation") + return cmd diff --git a/src/finn/custom_op/fpgadataflow/rtl/finn_loop.py b/src/finn/custom_op/fpgadataflow/rtl/finn_loop.py index dd3123b183..6909b936a5 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/finn_loop.py +++ b/src/finn/custom_op/fpgadataflow/rtl/finn_loop.py @@ -423,6 +423,11 @@ def generate_params(self, model, path): ): # rename so it doesn't get overwritten shutil.move(param_file, new_param_file) + # also rename simd-flipped npy for external_mem MVAU nodes + npy_file = "{}/input_1.npy".format(path) + if os.path.isfile(npy_file): + new_npy_file = "{}/{}_input1_{}.npy".format(path, param_node.op_type, iter) + shutil.move(npy_file, new_npy_file) elif param_node.op_type.startswith("Thresholding"): # get all generated Thresholding dat files pe = inst.get_nodeattr("PE") @@ -463,6 +468,17 @@ def generate_params(self, model, path): for line in infile: outfile.write(line) os.remove(memblock_file) + # concatenate all .npy files together (simd-flipped for AXI-MM sim) + npy_parts = [] + for iter in range(iteration): + npy_file = "{}/{}_input1_{}.npy".format(path, param_node.op_type, iter) + if os.path.isfile(npy_file): + npy_parts.append(np.load(npy_file)) + os.remove(npy_file) + if npy_parts: + combined_npy = np.concatenate(npy_parts, axis=1) + npy_out = "{}/input1_{}_id_{}.npy".format(path, param_node.op_type, i + 1) + np.save(npy_out, combined_npy) # Replace the path for the dat files in the ipgen files if Eltwise # Adapted from transformations.fpgadataflow.replace_verilog_relpaths if param_node.op_type.startswith("Elementwise"): @@ -582,6 +598,40 @@ def ipgen_singlenode_code(self, fpgapart=None): vivado_stitch_proj_dir = self.get_nodeattr("code_gen_dir_ipgen") cmd = [] + + # Create Vivado axis_dwidth_converter IPs for intermediate_frames DWCs + olen_bits = self.get_outstream_width(0) + ilen_bits = self.get_instream_width(0) + data_bits = 256 + # DWC write path: body output width -> DMA width (256) + dwc_sink_s_bytes = (olen_bits + 7) // 8 + dwc_sink_m_bytes = data_bits // 8 + cmd += [ + "create_ip -name axis_dwidth_converter -vendor xilinx.com " + "-library ip -version 1.1 -module_name if_dwc_sink", + "set_property -dict [list " + "CONFIG.S_TDATA_NUM_BYTES {%d} " + "CONFIG.M_TDATA_NUM_BYTES {%d} " + "CONFIG.HAS_TLAST {1} " + "CONFIG.HAS_TKEEP {1} " + "] [get_ips if_dwc_sink]" % (dwc_sink_s_bytes, dwc_sink_m_bytes), + "generate_target all [get_ips if_dwc_sink]", + ] + # DWC read path: DMA width (256) -> body input width + dwc_source_s_bytes = data_bits // 8 + dwc_source_m_bytes = (ilen_bits + 7) // 8 + cmd += [ + "create_ip -name axis_dwidth_converter -vendor xilinx.com " + "-library ip -version 1.1 -module_name if_dwc_source", + "set_property -dict [list " + "CONFIG.S_TDATA_NUM_BYTES {%d} " + "CONFIG.M_TDATA_NUM_BYTES {%d} " + "CONFIG.HAS_TLAST {1} " + "CONFIG.HAS_TKEEP {1} " + "] [get_ips if_dwc_source]" % (dwc_source_s_bytes, dwc_source_m_bytes), + "generate_target all [get_ips if_dwc_source]", + ] + # add all the generated IP dirs to ip_repo_paths ip_dirs = ["list"] # add RTL streamer IP diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index 9cd6fc2a9d..69c8c50126 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -31,7 +31,7 @@ from finn.custom_op.fpgadataflow.matrixvectoractivation import MVAU from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend -from finn.util.basic import get_dsp_block, is_versal +from finn.util.basic import get_dsp_block from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy # ONNX i/o tensor shape assumptions for MatrixVectorActivation_rtl: @@ -58,7 +58,6 @@ def get_nodeattr_types(self): def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") - dynamic_input = self.get_nodeattr("dynamic_input") mem_mode = self.get_nodeattr("mem_mode") node = self.onnx_node @@ -91,7 +90,7 @@ def execute_node(self, context, graph): ) if in_ind == 1: - if dynamic_input or self.get_nodeattr("mlo_max_iter"): + if mem_mode in ["dynamic", "external"]: reshaped_input = context[inputs].reshape(-1, context[inputs].shape[-1]) self.make_weight_file( reshaped_input, "decoupled_npy", "{}/input_1.npy".format(code_gen_dir) @@ -101,19 +100,15 @@ def execute_node(self, context, graph): nbits = self.get_instream_width() inp = npy_to_rtlsim_input("{}/input_0.npy".format(code_gen_dir), export_idt, nbits) super().reset_rtlsim(sim) - if ( - dynamic_input - or mem_mode in ["external", "internal_decoupled"] - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["external", "dynamic", "internal_decoupled", "external_mem"]: wnbits = self.get_instream_width(1) - if dynamic_input: + if mem_mode == "dynamic": wnbits = wnbits * self.get_nodeattr("SIMD") export_wdt = self.get_input_datatype(1) wei = npy_to_rtlsim_input("{}/input_1.npy".format(code_gen_dir), export_wdt, wnbits) num_w_reps = np.prod(self.get_nodeattr("numInputVectors")) - + num_w_reps = num_w_reps // self.get_nodeattr("TH") io_dict = { "inputs": {"in0": inp, "in1": wei * num_w_reps}, "outputs": {"out0": []}, @@ -164,23 +159,43 @@ def instantiate_ip(self, cmd): # instantiate the RTL IP node_name = self.onnx_node.name code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") - rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") - sourcefiles = [ - "mvu_pkg.sv", - "mvu_vvu_axi.sv", - "replay_buffer.sv", - "mvu.sv", - "mvu_vvu_8sx9_dsp58.sv", - "add_multi.sv", - ] + + theight = self.get_nodeattr("TH") + + if theight > 1: + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu_tiled/") + sourcefiles = [ + "../fifo/hdl/Q_srl.v", + "../skid/skid.sv", + "../mvu/mvu_pkg.sv", + "../mvu/add_multi.sv", + "mvu_tiled_axi.sv", + "cu_mvau_tiled.sv", + "acc_stage.sv", + "input_gen.sv", + "weights_buff_tile.sv", + ] + else: + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") + sourcefiles = [ + "mvu_pkg.sv", + "mvu_vvu_axi.sv", + "replay_buffer.sv", + "mvu.sv", + "mvu_vvu_8sx9_dsp58.sv", + "add_multi.sv", + ] sourcefiles = [ os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") ] + [rtllib_dir + _ for _ in sourcefiles] for f in sourcefiles: cmd.append("add_files -norecurse %s" % (f)) - mem_mode = self.get_nodeattr("mem_mode") - if mem_mode == "internal_decoupled" or self.get_nodeattr("mlo_max_iter"): + if self.get_nodeattr("mem_mode") in [ + "internal_decoupled", + "dynamic", + "external_mem", + ] or self.get_nodeattr("mlo_max_iter"): cmd.append( "create_bd_cell -type hier -reference %s /%s/%s" % ( @@ -280,9 +295,9 @@ def generate_hdl(self, model, fpgapart, clk): wdt = self.get_input_datatype(1) narrow_weights = ( 0 - if np.min(weights) == wdt.min() - or self.get_nodeattr("dynamic_input") - or (self.get_nodeattr("mlo_max_iter") > 1) + if weights is None + or np.min(weights) == wdt.min() + or self.get_nodeattr("mem_mode") in ["dynamic", "external_mem"] else 1 ) code_gen_dict["$NARROW_WEIGHTS$"] = str(narrow_weights) @@ -305,28 +320,20 @@ def generate_hdl(self, model, fpgapart, clk): ) as f: f.write(template_wrapper) - dynamic_input = self.get_nodeattr("dynamic_input") - mem_mode = self.get_nodeattr("mem_mode") + super().generate_hdl(fpgapart) - if dynamic_input: - self.generate_hdl_dynload() - elif mem_mode == "internal_decoupled" and not self.get_nodeattr("mlo_max_iter"): - if self.get_nodeattr("ram_style") == "ultra" and not is_versal(fpgapart): - runtime_writeable = self.get_nodeattr("runtime_writeable_weights") - assert ( - runtime_writeable == 1 - ), """Layer with URAM weights must have runtime_writeable_weights=1 - if Ultrascale device is targeted.""" - self.generate_hdl_memstream(fpgapart, pumped_memory=self.get_nodeattr("pumpedMemory")) - elif self.get_nodeattr("mlo_max_iter"): - self.generate_hdl_fetch_weights(fpgapart) # set ipgen_path and ip_path so that HLS-Synth transformation # and stich_ip transformation do not complain self.set_nodeattr("ipgen_path", code_gen_dir) self.set_nodeattr("ip_path", code_gen_dir) def prepare_codegen_default(self, fpgapart, clk): - template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mvu/mvu_vvu_axi_wrapper.v" + if self.get_nodeattr("TH") > 1: + template_path = ( + os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v" + ) + else: + template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mvu/mvu_vvu_axi_wrapper.v" # check if settings are valid pumped_compute = self.get_nodeattr("pumpedCompute") @@ -335,6 +342,7 @@ def prepare_codegen_default(self, fpgapart, clk): raise Exception( "Clock pumping an input of SIMD=1 is not meaningful. Please increase SIMD." ) + dsp_block = get_dsp_block(fpgapart) code_gen_dict = {} code_gen_dict["$IS_MVU$"] = [str(1)] @@ -344,6 +352,7 @@ def prepare_codegen_default(self, fpgapart, clk): code_gen_dict["$MH$"] = [str(self.get_nodeattr("MH"))] code_gen_dict["$PE$"] = [str(self.get_nodeattr("PE"))] code_gen_dict["$SIMD$"] = [str(simd)] + code_gen_dict["$TH$"] = [str(self.get_nodeattr("TH"))] code_gen_dict["$ACTIVATION_WIDTH$"] = [str(self.get_input_datatype(0).bitwidth())] code_gen_dict["$WEIGHT_WIDTH$"] = [str(self.get_input_datatype(1).bitwidth())] code_gen_dict["$ACCU_WIDTH$"] = [str(self.get_output_datatype().bitwidth())] @@ -357,26 +366,48 @@ def prepare_codegen_default(self, fpgapart, clk): def get_rtl_file_list(self, abspath=False): if abspath: code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/" - rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") + if self.get_nodeattr("TH") > 1: + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu_tiled/") + else: + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") else: code_gen_dir = "" rtllib_dir = "" - verilog_files = [ - "mvu_pkg.sv", - "mvu_vvu_axi.sv", - "replay_buffer.sv", - "mvu.sv", - "mvu_vvu_8sx9_dsp58.sv", - "add_multi.sv", - ] - verilog_files = [ - os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") - ] + [rtllib_dir + _ for _ in verilog_files] + if self.get_nodeattr("TH") > 1: + verilog_files = [ + "../fifo/hdl/Q_srl.v", + "../skid/skid.sv", + "../mvu/mvu_pkg.sv", + "../mvu/add_multi.sv", + "acc_stage.sv", + "input_gen.sv", + "weights_buff_tile.sv", + "cu_mvau_tiled.sv", + "mvu_tiled_axi.sv", + ] + verilog_files = [ + os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") + ] + [rtllib_dir + _ for _ in verilog_files] + else: + verilog_files = [ + "mvu_pkg.sv", + "mvu_vvu_axi.sv", + "replay_buffer.sv", + "mvu.sv", + "mvu_vvu_8sx9_dsp58.sv", + "add_multi.sv", + ] + verilog_files = [ + os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") + ] + [rtllib_dir + _ for _ in verilog_files] return verilog_files def get_verilog_paths(self): verilog_paths = super().get_verilog_paths() - verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu") + if self.get_nodeattr("TH") > 1: + verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled") + else: + verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu") return verilog_paths diff --git a/src/finn/custom_op/fpgadataflow/templates.py b/src/finn/custom_op/fpgadataflow/templates.py index 6ce0cac42c..5ec9159c2b 100644 --- a/src/finn/custom_op/fpgadataflow/templates.py +++ b/src/finn/custom_op/fpgadataflow/templates.py @@ -343,11 +343,7 @@ add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/cdma/cdma_x/cdma_x.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/cdma/cdma_x/cdma_x_rd.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/cdma/cdma_x/cdma_x_wr.sv" -add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/dwc/hdl/axis_adapter.v" -add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/dwc/hdl/axis_fifo.v" -add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/dwc/hdl/axis_fifo_adapter.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/skid/skid.sv" -add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/ram/ram_p_c.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/mlo/infrastructure/intermediate_frames.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/mlo/infrastructure/mux.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/mlo/infrastructure/demux.sv" diff --git a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py index f7b7beee14..789b817f69 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py @@ -1589,7 +1589,7 @@ def apply(self, model): noActivation=0, numInputVectors=list(mm_in_shape[:-1]), name="MVAU_" + n.name, - dynamic_input=W is None, + mem_mode="dynamic" if W is None else "internal_decoupled", inFIFODepths=[2, 2] if W is None else [2], ) graph.node.insert(node_ind, new_node) @@ -1621,7 +1621,7 @@ def apply(self, model): noActivation=1, numInputVectors=list(mm_in_shape[:-1]), name="MVAU_" + n.name, - dynamic_input=W is None, + mem_mode="dynamic" if W is None else "internal_decoupled", inFIFODepths=[2, 2] if W is None else [2], ) graph.node.insert(node_ind, new_node) diff --git a/src/finn/transformation/fpgadataflow/create_stitched_ip.py b/src/finn/transformation/fpgadataflow/create_stitched_ip.py index c713179e8f..e278495fb8 100644 --- a/src/finn/transformation/fpgadataflow/create_stitched_ip.py +++ b/src/finn/transformation/fpgadataflow/create_stitched_ip.py @@ -101,6 +101,7 @@ def __init__( self.run_synth = True self.signature = signature self.has_aximm = False + self.aximm_weight_files = {} self.aximm_idx = 0 self.has_m_axis = False self.m_axis_idx = 0 @@ -192,94 +193,51 @@ def connect_axi(self, node, model): ) self.intf_names["axilite"].append(ext_if_name) - if not node_inst.get_nodeattr("mlo_max_iter"): - if node.op_type == "FINNLoop": - for mm_intf_name in aximm_intf_name: - self.connect_cmds.append( - "make_bd_intf_pins_external [get_bd_intf_pins %s/%s]" - % (inst_name, mm_intf_name[0]) - ) - self.connect_cmds.append( - "set_property name %s [get_bd_intf_ports %s_0]" - % (mm_intf_name[0], mm_intf_name[0]) - ) - self.connect_cmds.append("assign_bd_address") - - if mm_intf_name[0] == "m_axi_hbm": - seg_name = "%s/%s/SEG_%s_Reg" % ( - inst_name, - mm_intf_name[0], - mm_intf_name[0], - ) - else: - seg_name = "%s/%s/SEG_%s_Reg" % ( - inst_name, - mm_intf_name[0], - mm_intf_name[0], - ) - self.connect_cmds.append( - "set_property offset 0 [get_bd_addr_segs {%s}]" % (seg_name) - ) - # TODO should propagate this information from the node instead of 256M - self.connect_cmds.append( - "set_property range 256M [get_bd_addr_segs {%s}]" % (seg_name) - ) - self.intf_names["aximm"].append((mm_intf_name[0], mm_intf_name[1])) - self.has_aximm = True - self.aximm_idx += 1 - - elif len(aximm_intf_name) != 0: - self.connect_cmds.append( - "make_bd_intf_pins_external [get_bd_intf_pins %s/%s]" - % (inst_name, aximm_intf_name[0][0]) - ) - ext_if_name = "m_axi_gmem%d" % (self.aximm_idx) - self.connect_cmds.append( - "set_property name %s [get_bd_intf_ports m_axi_gmem_0]" % ext_if_name - ) - self.connect_cmds.append("assign_bd_address") - seg_name = "%s/Data_m_axi_gmem/SEG_%s_Reg" % (inst_name, ext_if_name) - self.connect_cmds.append( - "set_property offset 0 [get_bd_addr_segs {%s}]" % (seg_name) - ) - # TODO should propagate this information from the node instead of 4G - self.connect_cmds.append( - "set_property range 4G [get_bd_addr_segs {%s}]" % (seg_name) - ) - self.intf_names["aximm"].append((ext_if_name, aximm_intf_name[0][1])) - self.has_aximm = True - self.aximm_idx += 1 - else: + is_mlo = node_inst.get_nodeattr("mlo_max_iter") + if is_mlo: self.is_mlo = True - for mm_intf_name in aximm_intf_name: + + for mm_intf_name in aximm_intf_name: + self.connect_cmds.append( + "make_bd_intf_pins_external [get_bd_intf_pins %s/%s]" % (inst_name, mm_intf_name[0]) + ) + + # Determine external interface name and address segment path + if node.op_type == "FINNLoop": + ext_if_name = mm_intf_name[0] self.connect_cmds.append( - "make_bd_intf_pins_external [get_bd_intf_pins %s/%s]" - % (inst_name, mm_intf_name[0]) + "set_property name %s [get_bd_intf_ports %s_0]" % (ext_if_name, ext_if_name) ) - # ext_if_name = "m_axi_gmem%d" % (self.aximm_idx) - # ext_if_name = f"m_axi_{inst_name}" - idx = inputs.index(node.input[1]) - ext_if_name = f"m_axi_MVAU_id_{idx}" + seg_name = "%s/%s/SEG_%s_Reg" % (inst_name, ext_if_name, ext_if_name) + else: + # Derive a unique name from graph input index or instance name + if node.input[1] in inputs: + idx = inputs.index(node.input[1]) + ext_if_name = f"m_axi_MVAU_id_{idx}" + else: + ext_if_name = f"m_axi_{inst_name}_{self.aximm_idx}" self.connect_cmds.append( "set_property name %s [get_bd_intf_ports axi_mm_0]" % (ext_if_name) ) - self.connect_cmds.append("assign_bd_address") - seg_name = "%s/%s_fetch_weights/axi_mm/SEG_%s_Reg" % ( inst_name, inst_name, ext_if_name, ) - self.connect_cmds.append( - "set_property offset 0 [get_bd_addr_segs {%s}]" % (seg_name) - ) - # TODO should propagate this information from the node instead of 256M - self.connect_cmds.append( - "set_property range 256M [get_bd_addr_segs {%s}]" % (seg_name) - ) - self.intf_names["aximm"].append((ext_if_name, mm_intf_name[1])) - self.has_aximm = True - self.aximm_idx += 1 + + self.connect_cmds.append("assign_bd_address") + self.connect_cmds.append("set_property offset 0 [get_bd_addr_segs {%s}]" % (seg_name)) + # TODO should propagate this information from the node instead of 256M + self.connect_cmds.append("set_property range 256M [get_bd_addr_segs {%s}]" % (seg_name)) + self.intf_names["aximm"].append((ext_if_name, mm_intf_name[1])) + # Track weight data files for AXI-MM simulation + if not node.op_type == "FINNLoop": + code_gen_dir = node_inst.get_nodeattr("code_gen_dir_ipgen") + npy_path = os.path.join(code_gen_dir, "input_1.npy") + if os.path.isfile(npy_path): + self.aximm_weight_files[ext_if_name] = npy_path + self.has_aximm = True + self.aximm_idx += 1 def connect_m_axis_external(self, node, idx=None): inst_name = node.name @@ -568,6 +526,10 @@ def apply(self, model): block_vlnv = "%s:%s:%s:1.0" % (block_vendor, block_library, block_name) model.set_metadata_prop("vivado_stitch_vlnv", block_vlnv) model.set_metadata_prop("vivado_stitch_ifnames", json.dumps(self.intf_names)) + if self.aximm_weight_files: + model.set_metadata_prop( + "vivado_stitch_aximm_weights", json.dumps(self.aximm_weight_files) + ) tcl.append( ( "ipx::package_project -root_dir %s/ip -vendor %s " diff --git a/src/finn/transformation/fpgadataflow/loop_rolling.py b/src/finn/transformation/fpgadataflow/loop_rolling.py index 62287ccc2c..4782a74aed 100644 --- a/src/finn/transformation/fpgadataflow/loop_rolling.py +++ b/src/finn/transformation/fpgadataflow/loop_rolling.py @@ -254,12 +254,12 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: print("error: could not find metadata for node") exit(1) - node.metadata_props["pkg.torch.onnx.name_scopes"] = mnode.metadata_props[ - "pkg.torch.onnx.name_scopes" - ] - node.metadata_props["pkg.torch.onnx.class_hierarchy"] = mnode.metadata_props[ - "pkg.torch.onnx.class_hierarchy" - ] + node.metadata_props["pkg.torch.onnx.name_scopes"] = mnode.metadata_props.get( + "pkg.torch.onnx.name_scopes", "" + ) + node.metadata_props["pkg.torch.onnx.class_hierarchy"] = mnode.metadata_props.get( + "pkg.torch.onnx.class_hierarchy", "" + ) assert P.add_node(node) graph.sort() diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py index fa317265a6..881ca2d7e6 100644 --- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py +++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py @@ -314,7 +314,7 @@ def apply(self, model): "ElementwiseMul_rtl", "ElementwiseSub_rtl", ] - modified_mlo_nodes = [] + modified_mlo_nodes = {} for node in model.graph.node: # verify assumptions assert is_hls_node(node) or is_rtl_node(node), "Found non-fpgadataflow node: " + str( @@ -354,14 +354,47 @@ def apply(self, model): "Changed mem_mode from external to internal_decoupled for " + node.onnx_node.name ) - # do necessary temporary settings for mlo nodes + # do necessary temporary settings for external_mem nodes if node.onnx_node.op_type in mlo_optypes: mlo_max_iter = node.get_nodeattr("mlo_max_iter") - if mlo_max_iter: - modified_mlo_nodes.append(node.onnx_node.name) + has_mem_mode = "mem_mode" in node.get_nodeattr_types() + mmode = node.get_nodeattr("mem_mode") if has_mem_mode else None + if mlo_max_iter or mmode == "external_mem": + node_mlo_info = { + "orig_mem_mode": mmode, + "orig_mlo_max_iter": mlo_max_iter, + "saved_initializer": None, + } node.set_nodeattr("mlo_max_iter", 0) if node.onnx_node.op_type.startswith("MVAU"): node.set_nodeattr("mem_mode", "external") + # If the weight tensor has an initializer and is not + # already a graph input, we must promote it so that + # InsertFIFO / CreateStitchedIP treat it as a streaming + # input during FIFO sizing simulation. + param_input = node.onnx_node.input[1] + input_names_set = {inp.name for inp in model.graph.input} + if param_input not in input_names_set: + # Save and remove initializer + node_mlo_info["saved_initializer"] = model.get_initializer(param_input) + model.del_initializer(param_input) + # Move value_info to graph.input + param_vi = model.get_tensor_valueinfo(param_input) + if param_vi is not None and param_vi in model.graph.value_info: + model.graph.value_info.remove(param_vi) + else: + param_shape = model.get_tensor_shape(param_input) + param_vi = helper.make_tensor_value_info( + param_input, TensorProto.FLOAT, param_shape + ) + model.graph.input.append(param_vi) + # Ensure inFIFODepths covers the weight stream (index 1) + # since it was computed while mem_mode was external_mem + ifd = node.get_nodeattr("inFIFODepths") + if len(ifd) <= 1: + w_size = np.prod(node.get_folded_input_shape(1)[:-1]) + ifd.append(int(w_size) if w_size > 1 else 2) + node.set_nodeattr("inFIFODepths", ifd) elif ( node.onnx_node.op_type == "Thresholding_rtl" or node.onnx_node.op_type.startswith("Elementwise") @@ -382,6 +415,7 @@ def apply(self, model): # since we converted the parameter to an initializer if node.onnx_node.op_type.startswith("Elementwise"): node.set_nodeattr("rhs_style", "const") + modified_mlo_nodes[node.onnx_node.name] = node_mlo_info self.mlo_max_iter = mlo_max_iter reset_implementation(node) # insert stream infrastructure (DWC/FIFO) @@ -438,6 +472,7 @@ def apply(self, model): # Apply depths back into the model; # also set in/outFIFODepths to zero for non-FIFO # nodes, preventing further FIFO insertion + weight_fifos_to_remove = [] for node in model.graph.node: # set FIFO depth, reset FIFO implementation, # and set implementation/ram styles @@ -472,20 +507,60 @@ def apply(self, model): node_inst.set_nodeattr("mem_mode", "external") reset_implementation(node_inst) modified_extw_nodes.remove(node.name) - # do the same resetting for mlo nodes + # do the same resetting for mlo / external_mem nodes if node.op_type in mlo_optypes: if node.name in modified_mlo_nodes and node.op_type.startswith("MVAU"): node_inst = getCustomOp(node) - node_inst.set_nodeattr("mlo_max_iter", self.mlo_max_iter) - node_inst.set_nodeattr("mem_mode", "internal_decoupled") + node_mlo_info = modified_mlo_nodes[node.name] + node_inst.set_nodeattr("mlo_max_iter", node_mlo_info["orig_mlo_max_iter"]) + node_inst.set_nodeattr("mem_mode", node_mlo_info["orig_mem_mode"]) + # Remove the weight-stream FIFO that was inserted during + # FIFO sizing (input index 1) and restore the original + # weight tensor connection. + if node_mlo_info["saved_initializer"] is not None: + # node.input[1] now points to FIFO output; find the FIFO + # param_input = node.input[1] + weight_fifo_out = node.input[1] + weight_fifo = model.find_producer(weight_fifo_out) + if weight_fifo is not None and weight_fifo.op_type.startswith( + "StreamingFIFO" + ): + # The original weight tensor is the FIFO's input + orig_weight_name = weight_fifo.input[0] + # Reconnect MVAU directly to original weight tensor + node.input[1] = orig_weight_name + # Defer removal of the FIFO node until after iteration + weight_fifos_to_remove.append((weight_fifo, weight_fifo_out)) + else: + orig_weight_name = weight_fifo_out + # Restore initializer and demote weight from graph input + model.set_initializer( + orig_weight_name, node_mlo_info["saved_initializer"] + ) + for gi in list(model.graph.input): + if gi.name == orig_weight_name: + model.graph.input.remove(gi) + model.graph.value_info.append(gi) + break reset_implementation(node_inst) - modified_mlo_nodes.remove(node.name) + del modified_mlo_nodes[node.name] + + # Remove weight-stream FIFOs that were deferred during the loop + for weight_fifo, weight_fifo_out in weight_fifos_to_remove: + model.graph.node.remove(weight_fifo) + for vi in list(model.graph.value_info): + if vi.name == weight_fifo_out: + model.graph.value_info.remove(vi) + break + if weight_fifo.name in fifos: + del fifos[weight_fifo.name] sorted_ind_map = dict(sorted(self.ind_map.items(), key=lambda item: item[1])) for k, v in sorted_ind_map.items(): node = model.get_node_from_name(k) node_inst = getCustomOp(node) - node_inst.set_nodeattr("mlo_max_iter", self.mlo_max_iter) + node_mlo_info = modified_mlo_nodes[node.name] + node_inst.set_nodeattr("mlo_max_iter", node_mlo_info["orig_mlo_max_iter"]) # remove initializer again param_input = node.input[1] param_input_vi = model.get_tensor_valueinfo(param_input) @@ -496,7 +571,7 @@ def apply(self, model): # Restore rhs_style to "input" (it must have been "input" for MLO nodes) node_inst.set_nodeattr("rhs_style", "input") reset_implementation(node_inst) - modified_mlo_nodes.remove(node.name) + del modified_mlo_nodes[node.name] assert ( len(modified_extw_nodes) == 0 and len(fifos.keys()) == 0 diff --git a/src/finn/util/mlo_sim.py b/src/finn/util/mlo_sim.py index 9f906767f3..8a1f47abf0 100644 --- a/src/finn/util/mlo_sim.py +++ b/src/finn/util/mlo_sim.py @@ -31,6 +31,7 @@ # aximm simulation tasks for handling the aximm interfaces. import numpy as np +import os from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.registry import getCustomOp from typing import Callable @@ -81,10 +82,36 @@ def mlo_prehook_func_factory(node) -> Callable[[SimEngine], None]: if downstream.op_type.startswith("MVAU"): mvau_hbm_weights[idx] = {} mvau_hbm_weights[idx]["name"] = lb_inp.name - datfile = ( - f"{finnloop_op.get_nodeattr('code_gen_dir_ipgen')}/memblock_MVAU_rtl_id_{idx}.dat" - ) - mvau_hbm_weights[idx]["value"] = dat_file_to_numpy_array(datfile) + code_gen_dir = finnloop_op.get_nodeattr("code_gen_dir_ipgen") + npy_file = f"{code_gen_dir}/input1_MVAU_rtl_id_{idx}.npy" + datfile = f"{code_gen_dir}/memblock_MVAU_rtl_id_{idx}.dat" + mvau_op = getCustomOp(downstream) + mh = mvau_op.get_nodeattr("MH") + mw = mvau_op.get_nodeattr("MW") + wdt_width = mvau_op.get_input_datatype(1).bitwidth() + # Must match RTL LAYER_OFFS: align to AXI bus width (256 bits = 32 bytes) + axi_bytes = 32 + raw_layer_bytes = (mh * mw * wdt_width + 7) // 8 + layer_offs = (raw_layer_bytes + axi_bytes - 1) & ~(axi_bytes - 1) + if os.path.isfile(npy_file): + weight_npy = np.load(npy_file) + # Pack npy values into byte array for AXI-MM + # Memory byte order matches npy_to_rtlsim_input packing + tinner = weight_npy.shape[-1] + words_per_iter = raw_layer_bytes // tinner + flat = weight_npy.reshape(-1, tinner) + n_iters = len(flat) // words_per_iter + weight_bytes = [] + for it in range(n_iters): + for row in flat[it * words_per_iter : (it + 1) * words_per_iter]: + for val in row: + weight_bytes.append(int(val) & 0xFF) + # Pad to layer_offs boundary + pad = layer_offs - raw_layer_bytes + weight_bytes.extend([0] * pad) + mvau_hbm_weights[idx]["value"] = np.array(weight_bytes, dtype=np.uint8) + else: + mvau_hbm_weights[idx]["value"] = dat_file_to_numpy_array(datfile) mvau_hbm_weights[idx]["extern_idx"] = extern_idx mvau_hbm_weights[idx]["extern_name"] = f"m_axi_MVAU_id_{idx}" extern_idx += 1 diff --git a/tests/fpgadataflow/test_fpgadataflow_finnloop.py b/tests/fpgadataflow/test_fpgadataflow_finnloop.py index 860eb6321c..ec822caa68 100644 --- a/tests/fpgadataflow/test_fpgadataflow_finnloop.py +++ b/tests/fpgadataflow/test_fpgadataflow_finnloop.py @@ -1,5 +1,6 @@ import pytest +import glob import numpy as np import os import re @@ -100,6 +101,10 @@ def make_loop_modelwrapper( rhs_shape=[1], eltw_param_dtype="INT8", name_suffix="", + mvau_pe=2, + mvau_simd=2, + mvau_th=1, + helper_pe=2, ): is_float = eltw_param_dtype == "FLOAT32" @@ -152,7 +157,7 @@ def make_loop_modelwrapper( { "NumChannels": mh, "NumOutputStreams": 2, - "PE": 8, + "PE": helper_pe, "inputDataType": dtype.name, "outFIFODepths": [2, 2], "cpp_interface": "hls_vector", @@ -167,14 +172,16 @@ def make_loop_modelwrapper( { "MW": mw, "MH": mh, - "SIMD": 2, - "PE": 2, + "SIMD": mvau_simd, + "PE": mvau_pe, + "TH": mvau_th, "inputDataType": "INT8", "weightDataType": "INT8", "outputDataType": "INT32", "ActVal": 0, "binaryXnorMode": 0, "noActivation": 1, + "mem_mode": "external_mem", }, ), create_node( @@ -184,7 +191,7 @@ def make_loop_modelwrapper( f"Thresholding_rtl_0{name_suffix}", { "NumChannels": mh, - "PE": 2, + "PE": helper_pe, "inputDataType": "INT32", "weightDataType": "INT33", "outputDataType": dtype.name, @@ -200,14 +207,16 @@ def make_loop_modelwrapper( { "MW": mw, "MH": mh, - "SIMD": 2, - "PE": 2, + "SIMD": mvau_simd, + "PE": mvau_pe, + "TH": mvau_th, "inputDataType": "INT8", "weightDataType": "INT8", "outputDataType": "INT32", "ActVal": 0, "binaryXnorMode": 0, "noActivation": 1, + "mem_mode": "external_mem", }, ), create_node( @@ -217,7 +226,7 @@ def make_loop_modelwrapper( f"Thresholding_rtl_1{name_suffix}", { "NumChannels": mh, - "PE": 2, + "PE": helper_pe, "inputDataType": "INT32", "weightDataType": "INT33", "outputDataType": dtype.name, @@ -233,14 +242,16 @@ def make_loop_modelwrapper( { "MW": mw, "MH": mh, - "SIMD": 2, - "PE": 2, + "SIMD": mvau_simd, + "PE": mvau_pe, + "TH": mvau_th, "inputDataType": "INT8", "weightDataType": "INT8", "outputDataType": "INT32", "ActVal": 0, "binaryXnorMode": 0, "noActivation": 1, + "mem_mode": "external_mem", }, ), create_node( @@ -250,7 +261,7 @@ def make_loop_modelwrapper( f"Thresholding_rtl_2{name_suffix}", { "NumChannels": mh, - "PE": 2, + "PE": helper_pe, "inputDataType": "INT32", "weightDataType": "INT33", "outputDataType": dtype.name, @@ -272,7 +283,7 @@ def make_loop_modelwrapper( "out_dtype": "INT9", "lhs_style": "input", "rhs_style": "input", - "PE": 2, + "PE": helper_pe, }, ), create_node( @@ -323,7 +334,7 @@ def make_loop_modelwrapper( f"Thresholding_rtl4{name_suffix}", { "NumChannels": mh, - "PE": 4, + "PE": helper_pe, "numSteps": dtype.get_num_possible_values() - 1, "inputDataType": thresholding_input_dtype.name, "weightDataType": thresholding_input_dtype.name, @@ -410,8 +421,102 @@ def make_loop_modelwrapper( return loop_body_model +def make_single_mvau_loop_body( + mw, + mh, + dtype=DataType["INT8"], + name_suffix="", + mvau_pe=2, + mvau_simd=2, + mvau_th=1, + helper_pe=2, +): + """Create a minimal loop body with just MVAU_rtl -> Thresholding_rtl.""" + + W0 = gen_finn_dt_tensor(dtype, (mw, mh)) + T0 = np.sort( + generate_random_threshold_values(dtype, 1, dtype.get_num_possible_values() - 1), axis=1 + ) + + nodes = [ + create_node( + "MVAU_rtl", + [f"ifm{name_suffix}", f"weights0{name_suffix}"], + [f"mm0_out{name_suffix}"], + f"MVAU_rtl_0{name_suffix}", + { + "MW": mw, + "MH": mh, + "SIMD": mvau_simd, + "PE": mvau_pe, + "TH": mvau_th, + "inputDataType": "INT8", + "weightDataType": "INT8", + "outputDataType": "INT32", + "ActVal": 0, + "binaryXnorMode": 0, + "noActivation": 1, + "mem_mode": "external_mem", + }, + ), + create_node( + "Thresholding_rtl", + [f"mm0_out{name_suffix}", f"thresh0{name_suffix}"], + [f"ofm{name_suffix}"], + f"Thresholding_rtl_0{name_suffix}", + { + "NumChannels": mh, + "PE": helper_pe, + "inputDataType": "INT32", + "weightDataType": "INT33", + "outputDataType": dtype.name, + "ActVal": int(dtype.min()), + "numSteps": dtype.get_num_possible_values() - 1, + }, + ), + ] + + loop_body = helper.make_graph( + nodes=nodes, + name=f"single_mvau_graph{name_suffix}", + inputs=[ + create_tensor_info(f"ifm{name_suffix}", [1, 3, 3, mw]), + create_threshold(f"thresh0{name_suffix}", (1, dtype.get_num_possible_values() - 1)), + ], + outputs=[create_tensor_info(f"ofm{name_suffix}", (1, 3, 3, mh))], + value_info=[ + create_tensor_info(f"mm0_out{name_suffix}", [1, 3, 3, mh]), + ], + ) + + loop_body_model = qonnx_make_model(loop_body, producer_name=f"single-mvau-body{name_suffix}") + loop_body_model = ModelWrapper(loop_body_model) + + loop_body_model.set_initializer(f"weights0{name_suffix}", W0) + loop_body_model.set_initializer(f"thresh0{name_suffix}", T0) + + for tensor in [ + f"weights0{name_suffix}", + f"thresh0{name_suffix}", + f"ifm{name_suffix}", + f"ofm{name_suffix}", + ]: + loop_body_model.set_tensor_datatype(tensor, dtype) + + return loop_body_model + + def create_chained_loop_bodies( - mw, mh, num_copies, elemwise_optype="ElementwiseMul_hls", rhs_shape=[1], eltw_param_dtype="INT8" + mw, + mh, + num_copies, + elemwise_optype="ElementwiseMul_hls", + rhs_shape=[1], + eltw_param_dtype="INT8", + mvau_pe=2, + mvau_simd=2, + mvau_th=1, + helper_pe=2, ): loop_body_models = [] @@ -426,6 +531,10 @@ def create_chained_loop_bodies( rhs_shape=rhs_shape, eltw_param_dtype=eltw_param_dtype, name_suffix=name_suffix, + mvau_pe=mvau_pe, + mvau_simd=mvau_simd, + mvau_th=mvau_th, + helper_pe=helper_pe, ) loop_body_models.append(loop_body_model) @@ -622,6 +731,182 @@ def test_finnloop_end2end_mlo( ), f"Check vivado.log in {tmp_output_dir}/stitched_ip" +# iteration count, number of models chained together +@pytest.mark.parametrize("iteration", [3]) +# elementwise operation +@pytest.mark.parametrize("elemwise_optype", ["ElementwiseMul_hls"]) +# elementwise shape +@pytest.mark.parametrize("rhs_shape", [[1]]) +# eltwise param dtype +@pytest.mark.parametrize("eltw_param_dtype", ["INT8"]) +# tail node +@pytest.mark.parametrize("tail_node", [False]) +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_finnloop_end2end_mlo_tiled( + iteration, elemwise_optype, rhs_shape, eltw_param_dtype, tail_node +): + """End-to-end MLO test with tiled MVAUs (TH>1).""" + dim = 12 + mvau_pe = 6 + mvau_simd = 3 + mvau_th = 3 + helper_pe = 6 + + # Check vivado version + vivado_path = os.environ.get("XILINX_VIVADO") + match = re.search(r"\b(20\d{2})\.(1|2)\b", vivado_path) + year, minor = int(match.group(1)), int(match.group(2)) + if (year, minor) < (2024, 2): + pytest.skip("""At least Vivado version 2024.2 needed for MLO.""") + loop_body_models = create_chained_loop_bodies( + dim, + dim, + iteration, + elemwise_optype, + rhs_shape, + eltw_param_dtype, + mvau_pe=mvau_pe, + mvau_simd=mvau_simd, + mvau_th=mvau_th, + helper_pe=helper_pe, + ) + model = loop_body_models[0] + for m in loop_body_models[1:]: + model = model.transform(MergeONNXModels(m)) + + if tail_node: + tail_outp = create_tensor_info("tail_outp", [1, 3, 3, dim]) + tr_node = create_node( + "ElementwiseAdd_hls", + [model.graph.output[0].name, "tail_add"], + ["tail_outp"], + "Add_tail", + { + "lhs_shape": [1, 3, 3, dim], + "rhs_shape": [1], + "out_shape": [1, 3, 3, dim], + "lhs_dtype": "INT8", + "rhs_dtype": "INT8", + "out_dtype": "INT9", + }, + ) + model.graph.node.insert(len(model.graph.node), tr_node) + model.graph.value_info.append(model.graph.output[0]) + model.graph.output.pop(0) + model.graph.output.append(tail_outp) + AddtailParam = gen_finn_dt_tensor(DataType["INT8"], [1]) + model.set_initializer("tail_add", AddtailParam) + model.set_tensor_datatype("tail_add", DataType["INT8"]) + + # cleanup + model = model.transform(RemoveUnusedTensors()) + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + + # Generate reference by first copying the model and running cppsim + model_ref = model.transform(PrepareCppSim()) + model_ref = model_ref.transform(CompileCppSim()) + model_ref = model_ref.transform(SetExecMode("cppsim")) + + # generate reference io pair + x = gen_finn_dt_tensor(DataType["INT8"], (1, 3, 3, dim)) + io_dict = {model_ref.graph.input[0].name: x} + y_dict = oxe.execute_onnx(model_ref, io_dict) + y_ref = y_dict[model_ref.graph.output[0].name] + + tmp_output_dir = make_build_dir("build_mlo_tiled") + + np.save(tmp_output_dir + "/input.npy", x) + np.save(tmp_output_dir + "/expected_output.npy", y_ref) + + model.save(tmp_output_dir + "/mlo_model.onnx") + + # steps - skip step_target_fps_parallelization since PE/SIMD/TH already set + steps = [ + "step_create_dataflow_partition", + "step_loop_rolling", + "step_apply_folding_config", + "step_minimize_bit_width", + "step_generate_estimate_reports", + "step_hw_codegen", + "step_hw_ipgen", + "step_set_fifo_depths", + "step_create_stitched_ip", + ] + + cfg = build_cfg.DataflowBuildConfig( + output_dir=tmp_output_dir, + steps=steps, + target_fps=1000, + synth_clk_period_ns=10.0, + board="V80", + rtlsim_batch_size=100, + standalone_thresholds=True, + mlo=True, + loop_body_hierarchy=[["", "layers.0"]], + loop_body_range=(model.graph.node[0], model.graph.node[9]), + verify_steps=verif_steps, + verify_input_npy=tmp_output_dir + "/input.npy", + verify_expected_output_npy=tmp_output_dir + "/expected_output.npy", + verify_save_rtlsim_waveforms=True, + generate_outputs=[ + build_cfg.DataflowOutputType.ESTIMATE_REPORTS, + build_cfg.DataflowOutputType.STITCHED_IP, + ], + ) + build.build_dataflow_cfg(tmp_output_dir + "/mlo_model.onnx", cfg) + + # Dump weight files for hardware debug + built_model = ModelWrapper(tmp_output_dir + "/mlo_model.onnx") + for node in built_model.graph.node: + if node.op_type == "FINNLoop": + fl_op = getCustomOp(node) + code_gen_dir = fl_op.get_nodeattr("code_gen_dir_ipgen") + if code_gen_dir and os.path.isdir(code_gen_dir): + for f in sorted(glob.glob(code_gen_dir + "/input1_*.npy")): + arr = np.load(f) + base = os.path.basename(f).replace(".npy", "") + txt_path = tmp_output_dir + f"/weights_{base}.txt" + with open(txt_path, "w") as tf: + tf.write(f"# {f}\n# shape: {arr.shape}, dtype: {arr.dtype}\n") + tf.write("# row | decimal_values | hex_bytes | bus_word\n\n") + flat = arr.reshape(-1, arr.shape[-1]) + for i, row in enumerate(flat): + dec = " ".join(f"{int(v):5d}" for v in row) + hx = " ".join(f"{int(v) & 0xFF:02x}" for v in row) + bus = 0 + for j, v in enumerate(row): + bus |= (int(v) & 0xFF) << (j * 8) + tf.write(f"[{i:3d}] {dec} | {hx} | 0x{bus:0{arr.shape[-1]*2}x}\n") + print(f"DEBUG: weight dump -> {txt_path}") + for f in sorted(glob.glob(code_gen_dir + "/memblock_*.dat")): + base = os.path.basename(f).replace(".dat", "") + txt_path = tmp_output_dir + f"/weights_{base}.txt" + with open(txt_path, "w") as tf: + tf.write(f"# {f}\n") + with open(f) as df: + for i, line in enumerate(df): + tf.write(f"[{i:3d}] {line.strip()}\n") + print(f"DEBUG: dat dump -> {txt_path}") + + # check if expected files are there + assert os.path.isfile(tmp_output_dir + "/loop-body-template.onnx") + assert os.path.isfile(tmp_output_dir + "/stitched_ip/ip/component.xml") + + verif_dir = tmp_output_dir + "/verification_output" + assert os.path.isfile( + verif_dir + "/verify_folded_hls_cppsim_0_SUCCESS.npy" + ), f"Check npy files in {verif_dir}" + assert os.path.isfile( + verif_dir + "/verify_node_by_node_rtlsim_0_SUCCESS.npy" + ), f"Check npy files in {verif_dir}" + assert os.path.isfile( + verif_dir + "/verify_stitched_ip_rtlsim_0_SUCCESS.npy" + ), f"Check npy files in {verif_dir}" + + # Debug test for manual loop transformation steps below # This test is intentionally not marked for CI # Use test_finnloop_end2end_mlo instead diff --git a/tests/fpgadataflow/test_fpgadataflow_mvau.py b/tests/fpgadataflow/test_fpgadataflow_mvau.py index 4128092df1..378ad17b22 100644 --- a/tests/fpgadataflow/test_fpgadataflow_mvau.py +++ b/tests/fpgadataflow/test_fpgadataflow_mvau.py @@ -857,14 +857,111 @@ def test_fpgadataflow_rtl_mvau( ).all(), "Output of ONNX model not matching output of stitched-IP RTL model!" +@pytest.mark.parametrize("mh", [12]) +@pytest.mark.parametrize("mw", [12]) +@pytest.mark.parametrize("pe", [6]) +@pytest.mark.parametrize("simd", [3]) +@pytest.mark.parametrize("th", [3]) +@pytest.mark.parametrize("idt_wdt", [[DataType["UINT8"], DataType["INT8"]]]) +@pytest.mark.parametrize("clk_ns", [4]) +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +def test_fpgadataflow_rtl_tiled_mvau(mh, mw, pe, simd, th, idt_wdt, clk_ns): + # Tiled MVAU only supported on Versal (DSP58) + part = "xcvc1902-vsva2197-2MP-e-S" + + if (pe * simd) % th != 0: + pytest.skip("(PE * SIMD) must be divisible by TH") + + if mw % simd != 0: + pytest.skip("MW must be divisible by SIMD") + + if mh % pe != 0: + pytest.skip("MH must be divisible by PE") + + idt, wdt = idt_wdt + # Create test input vector (produced by SWG) + ofm_shape = (3, 3) + ofm_h, ofm_w = ofm_shape + ifm = helper.make_tensor_value_info("ifm", TensorProto.FLOAT, [1, ofm_h, ofm_w, mw]) + ofm = helper.make_tensor_value_info("ofm", TensorProto.FLOAT, (1, ofm_h, ofm_w, mh)) + W = gen_finn_dt_tensor(wdt, (mw, mh)) + model = make_single_matmul_modelwrapper(ifm, ofm, idt, wdt, W) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + + # Create MatMul & obtain golden reference output + A = gen_finn_dt_tensor( + model.get_tensor_datatype("global_in"), model.get_tensor_shape("global_in") + ) + input_dict = prepare_inputs(A, idt, wdt, inp_name="global_in") + + # Execute ONNX model + output_matmul = oxe.execute_onnx(model, input_dict)["global_out"] + + # Create MVAU + model = model.transform(to_hw.InferQuantizedMatrixVectorActivation()) + model = model.transform(GiveUniqueNodeNames()) + + # Apply convert-to-rtl step + model = model.transform(SpecializeLayers(part)) + model = model.transform(GiveUniqueNodeNames()) + + assert model.graph.node[0].op_type == "MVAU_rtl" + # Apply folding with TH for tiled implementation + folding_config = { + "Defaults": {}, + "MVAU_rtl_0": { + "PE": pe, + "SIMD": simd, + "TH": th, + "resType": "dsp", + "mem_mode": "external_mem", + }, + } + model = model.transform(ApplyConfig(folding_config)) + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + # make sure the changed datatypes are propagated through the network + model = model.transform(InferDataTypes()) + + # Run CPPsim + model = model.transform(SetExecMode("cppsim")) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + output_mvau_hls = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_hls + ).all(), "Output of ONNX model not matching output of node-by-node CPPsim!" + + # Run node-by-node RTLsim + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareIP(part, clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + output_mvau_rtl = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_rtl + ).all(), "Output of ONNX model not matching output of node-by-node RTLsim!" + + # Run stitched-ip RTLsim + model = model.transform(InsertAndSetFIFODepths(part, clk_ns)) + model = model.transform(PrepareIP(part, clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP(part, clk_ns)) + output_mvau_rtl_stitch = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_rtl_stitch + ).all(), "Output of ONNX model not matching output of tiled stitched-IP RTL model!" + + @pytest.mark.parametrize("mh", [32]) @pytest.mark.parametrize("mw", [16]) @pytest.mark.parametrize("n_vectors", [32]) @pytest.mark.parametrize("pe", [1, 16, 32]) @pytest.mark.parametrize("simd", [1, 8, 16]) -@pytest.mark.parametrize( - "idt_wdt", [[DataType["INT8"], DataType["INT8"]], [DataType["INT4"], DataType["INT4"]]] -) +@pytest.mark.parametrize("idt_wdt", [[DataType["INT4"], DataType["INT4"]]]) @pytest.mark.parametrize( "part", ["xcvc1902-vsva2197-2MP-e-S", "xcku3p-ffva676-1-e", "xc7z020clg400-1"] )