Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions finn-rtllib/dwc/hdl/axis_dwc.sv
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
//
// This file is subject to the Xilinx Design License Agreement located
// in the LICENSE.md file in the root directory of this repository.
//
// This file contains confidential and proprietary information of Xilinx, Inc.
// and is protected under U.S. and international copyright and other
// intellectual property laws.
//
// DISCLAIMER
// This disclaimer is not a license and does not grant any rights to the materials
// distributed herewith. Except as otherwise provided in a valid license issued to
// you by Xilinx, and to the maximum extent permitted by applicable law: (1) THESE
// MATERIALS ARE MADE AVAILABLE "AS IS" AND WITH ALL FAULTS, AND XILINX HEREBY
// DISCLAIMS ALL WARRANTIES AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY,
// INCLUDING BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR
// FITNESS FOR ANY PARTICULAR PURPOSE; and (2) Xilinx shall not be liable (whether
// in contract or tort, including negligence, or under any other theory of
// liability) for any loss or damage of any kind or nature related to, arising
// under or in connection with these materials, including for any direct, or any
// indirect, special, incidental, or consequential loss or damage (including loss
// of data, profits, goodwill, or any type of loss or damage suffered as a result
// of any action brought by a third party) even if such damage or loss was
// reasonably foreseeable or Xilinx had been advised of the possibility of the
// same.
//
// CRITICAL APPLICATIONS
// Xilinx products are not designed or intended to be fail-safe, or for use in
// any application requiring failsafe performance, such as life-support or safety
// devices or systems, Class III medical devices, nuclear facilities, applications
// related to the deployment of airbags, or any other applications that could lead
// to death, personal injury, or severe property or environmental damage
// (individually and collectively, "Critical Applications"). Customer assumes the
// sole risk and liability of any use of Xilinx products in Critical Applications,
// subject only to applicable laws and regulations governing limitations on product
// liability.
//
// THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS PART OF THIS FILE AT ALL TIMES.

// axis_dwc
//
// Thin AXI-Stream wrapper around `axis_fifo_adapter`. It exposes only the
// tvalid/tready/tdata/tkeep/tlast signals of the slave (s_axis_*) and master
// (m_axis_*) sides and forwards the width parameters to the adapter.
//
// Parameters:
//   DEPTH        - forwarded unchanged to axis_fifo_adapter.DEPTH
//                  (NOTE(review): units/semantics are defined by the adapter - confirm).
//   S_DATA_BITS  - width of s_axis_tdata; forwarded as S_DATA_WIDTH.
//   M_DATA_BITS  - width of m_axis_tdata; forwarded as M_DATA_WIDTH.
//
// The tkeep ports are sized as <DATA_BITS>/8 using integer division.
// NOTE(review): this presumably expects both data widths to be multiples of
// 8 bits - confirm against the callers.
module axis_dwc #(
    parameter integer DEPTH       = 512,
    parameter integer S_DATA_BITS = 32,
    parameter integer M_DATA_BITS = 8
) (
    input  logic                     aclk,
    input  logic                     aresetn,   // active-low; inverted for the adapter below

    // Slave (input) stream
    input  logic                     s_axis_tvalid,
    output logic                     s_axis_tready,
    input  logic [S_DATA_BITS-1:0]   s_axis_tdata,
    input  logic [S_DATA_BITS/8-1:0] s_axis_tkeep,
    input  logic                     s_axis_tlast,

    // Master (output) stream
    output logic                     m_axis_tvalid,
    input  logic                     m_axis_tready,
    output logic [M_DATA_BITS-1:0]   m_axis_tdata,
    output logic [M_DATA_BITS/8-1:0] m_axis_tkeep,
    output logic                     m_axis_tlast
);

    axis_fifo_adapter #(
        .DEPTH        (DEPTH),
        .S_DATA_WIDTH (S_DATA_BITS),
        .M_DATA_WIDTH (M_DATA_BITS)
    ) inst_fifo_adapter (
        .clk (aclk),
        .rst (~aresetn),    // adapter reset is driven with the inverted active-low reset

        // Input stream; the id/dest/user sidebands are not exposed by this
        // wrapper and are tied to zero.
        .s_axis_tdata  (s_axis_tdata),
        .s_axis_tkeep  (s_axis_tkeep),
        .s_axis_tvalid (s_axis_tvalid),
        .s_axis_tready (s_axis_tready),
        .s_axis_tlast  (s_axis_tlast),
        .s_axis_tid    ('0),
        .s_axis_tdest  ('0),
        .s_axis_tuser  ('0),

        // Pause interface is unused: the request is held at zero and the
        // acknowledge is left explicitly open rather than implicitly unconnected.
        .pause_req ('0),
        .pause_ack (),

        // Output stream; the id/dest/user sideband outputs are left open.
        .m_axis_tdata  (m_axis_tdata),
        .m_axis_tkeep  (m_axis_tkeep),
        .m_axis_tvalid (m_axis_tvalid),
        .m_axis_tready (m_axis_tready),
        .m_axis_tlast  (m_axis_tlast),
        .m_axis_tid    (),
        .m_axis_tdest  (),
        .m_axis_tuser  ()
    );

endmodule
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
module fetch_weights #(
int unsigned PE,
int unsigned SIMD,
int unsigned TH = 1,
int unsigned MH,
int unsigned MW,
int unsigned N_REPS,
Expand All @@ -45,18 +46,22 @@ module fetch_weights #(

int unsigned N_LAYERS,

int unsigned EN_MLO = 1,

int unsigned QDEPTH = 8,
int unsigned EN_OREG = 1,
int unsigned N_DCPL_STGS = 1,
int unsigned DBG = 0,

// Safely deducible parameters
int unsigned DS_BITS_BA = (SIMD*WEIGHT_WIDTH+7)/8 * 8,
int unsigned WS_BITS_BA = (PE*SIMD*WEIGHT_WIDTH+7)/8 * 8,
logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7 // 8-byte aligned
int unsigned IWSIMD = (TH > 1) ? ((PE*SIMD)/TH) : SIMD,
int unsigned OWSIMD = (PE * SIMD) / TH,
int unsigned DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8,
int unsigned WS_BITS_BA = (OWSIMD*WEIGHT_WIDTH+7)/8 * 8,
logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8 + (DATA_BITS/8-1)) & ~(DATA_BITS/8-1) // AXI bus-width aligned
) (
input logic aclk,
input logic aresetn,
input wire aclk,
input wire aresetn,

output logic m_done,

Expand Down Expand Up @@ -102,46 +107,143 @@ module fetch_weights #(
output logic s_idx_tready,
input logic[IDX_BITS-1:0] s_idx_tdata,

// DMA stream out (to external width converter)
output logic axis_dma_tvalid,
input logic axis_dma_tready,
output logic[DATA_BITS-1:0] axis_dma_tdata,
output logic[DATA_BITS/8-1:0] axis_dma_tkeep,
output logic axis_dma_tlast,

// DWC stream in (from external width converter)
input logic axis_dwc_tvalid,
output logic axis_dwc_tready,
input logic[DS_BITS_BA-1:0] axis_dwc_tdata,
input logic[(DS_BITS_BA)/8-1:0] axis_dwc_tkeep,
input logic axis_dwc_tlast,

// Stream
// TODO: Should we reg this? Would be quite wide ...
output logic m_axis_tvalid,
input logic m_axis_tready,
output logic[WS_BITS_BA-1:0] m_axis_tdata
);

localparam int unsigned WMAT_SIZE = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7;

// Offsets
logic [N_LAYERS-1:0][ADDR_BITS-1:0] l_offsets;
for(genvar i = 0; i < N_LAYERS; i++) begin
assign l_offsets[i] = (i * LAYER_OFFS);
end

logic q_idx_out_tvalid, q_idx_out_tready;
logic [IDX_BITS-1:0] q_idx_out_tdata;
logic [ADDR_BITS-1:0] q_dma_addr;
logic [LEN_BITS-1:0] q_dma_len;

// Queues
Q_srl #(
.depth(QDEPTH),
.width(IDX_BITS)
) inst_queue_in (
.clock(aclk), .reset(!aresetn),
.count(), .maxcount(),
.i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready),
.o_d(q_idx_out_tdata), .o_v(q_idx_out_tvalid), .o_r(q_idx_out_tready)
);
//
// Indexes and DMA
//

logic dma_tvalid;
logic dma_tready;
logic [ADDR_BITS-1:0] dma_addr;
logic [LEN_BITS-1:0] dma_len;

if(TH > 1) begin

// Consts
localparam integer REPS_BITS = (N_REPS == 1) ? 1 : $clog2(N_REPS);

// Reps
typedef enum logic[0:0] {ST_IDLE, ST_DMA} state_t;
state_t state_C = ST_IDLE, state_N;

logic [REPS_BITS-1:0] cnt_dma_C = '0, cnt_dma_N;
logic [IDX_BITS-1:0] idx_C = '0, idx_N;

logic q_idx_out_tvalid, q_idx_out_tready;
logic [IDX_BITS-1:0] q_idx_out_tdata;

// Idx queue
Q_srl #(
.depth(QDEPTH),
.width(IDX_BITS)
) inst_queue_in (
.clock(aclk), .reset(!aresetn),
.count(), .maxcount(),
.i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready),
.o_d(q_idx_out_tdata), .o_v(q_idx_out_tvalid), .o_r(q_idx_out_tready)
);

assign dma_addr = l_offsets[idx_C];
assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7;

always_ff @( posedge aclk ) begin: REG
if(~aresetn) begin
state_C <= ST_IDLE;

cnt_dma_C <= '0;
idx_C <= 'X;
end else begin
state_C <= state_N;

cnt_dma_C <= cnt_dma_N;
idx_C <= idx_N;
end
end

always_comb begin: NSL
state_N = state_C;

case (state_C)
ST_IDLE:
state_N = q_idx_out_tvalid ? ST_DMA : ST_IDLE;

ST_DMA:
state_N = (cnt_dma_C == N_REPS-1) && dma_tready ? ST_IDLE : ST_DMA;

assign q_dma_addr = l_offsets[q_idx_out_tdata];
assign q_dma_len = WMAT_SIZE;
endcase
end

// DMA
logic axis_dma_tvalid;
logic axis_dma_tready;
logic[DATA_BITS-1:0] axis_dma_tdata;
logic[DATA_BITS/8-1:0] axis_dma_tkeep;
logic axis_dma_tlast;
always_comb begin: DP
cnt_dma_N = cnt_dma_C;
idx_N = idx_C;

q_idx_out_tready = 1'b0;
dma_tvalid = 1'b0;

case (state_C)
ST_IDLE: begin
q_idx_out_tready = 1'b1;
cnt_dma_N = 0;
if(q_idx_out_tvalid) begin
idx_N = q_idx_out_tdata;
end
end

ST_DMA: begin
dma_tvalid = 1'b1;
if(dma_tready) begin
cnt_dma_N = cnt_dma_C + 1;
end
end

endcase
end

end else begin

// Idx queue
logic [IDX_BITS-1:0] q_idx_out_tdata;

Q_srl #(
.depth(QDEPTH),
.width(IDX_BITS)
) inst_idx_queue (
.clock(aclk), .reset(!aresetn),
.count(), .maxcount(),
.i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready),
.o_d(q_idx_out_tdata), .o_v(dma_tvalid), .o_r(dma_tready)
);

assign dma_addr = l_offsets[q_idx_out_tdata];
assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7;

end

cdma_u_rd #(
.DATA_BITS(DATA_BITS),
Expand All @@ -150,8 +252,8 @@ cdma_u_rd #(
) inst_dma (
.aclk(aclk), .aresetn(aresetn),

.rd_valid(q_idx_out_tvalid), .rd_ready(q_idx_out_tready),
.rd_paddr(q_dma_addr), .rd_len(q_dma_len),
.rd_valid(dma_tvalid), .rd_ready(dma_tready),
.rd_paddr(dma_addr), .rd_len(dma_len),
.rd_done(m_done),

.m_axi_ddr_arvalid(m_axi_ddr_arvalid),
Expand All @@ -178,42 +280,32 @@ cdma_u_rd #(
.m_axis_ddr_tlast(axis_dma_tlast)
);

// Width conversion
logic axis_dwc_tvalid;
logic axis_dwc_tready;
logic[DS_BITS_BA-1:0] axis_dwc_tdata;
logic[(DS_BITS_BA)/8-1:0] axis_dwc_tkeep;
logic axis_dwc_tlast;

axis_fifo_adapter #(
.S_DATA_WIDTH(DATA_BITS), .M_DATA_WIDTH(DS_BITS_BA)
) inst_dwc (
.clk(aclk), .rst(~aresetn),
.pause_req('0), .s_axis_tid('0), .s_axis_tdest('0), .s_axis_tuser('0),
.s_axis_tvalid(axis_dma_tvalid), .s_axis_tready(axis_dma_tready), .s_axis_tdata(axis_dma_tdata), .s_axis_tkeep(axis_dma_tkeep), .s_axis_tlast(axis_dma_tlast),
.pause_ack(), .m_axis_tid(), .m_axis_tdest(), .m_axis_tuser(),
.m_axis_tvalid(axis_dwc_tvalid), .m_axis_tready(axis_dwc_tready), .m_axis_tdata(axis_dwc_tdata), .m_axis_tkeep(axis_dwc_tkeep), .m_axis_tlast(axis_dwc_tlast)
);

// Double buffer
// Local weight buffer
// Only for non-tiled nodes
logic axis_lwb_tvalid;
logic axis_lwb_tready;
logic[WS_BITS_BA-1:0] axis_lwb_tdata;

local_weight_buffer #(
.PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), .DBG(DBG)
) inst_weight_buff (
.clk(aclk), .rst(~aresetn),
.ivld(axis_dwc_tvalid), .irdy(axis_dwc_tready), .idat(axis_dwc_tdata),
.ovld(axis_lwb_tvalid), .ordy(axis_lwb_tready), .odat(axis_lwb_tdata)
);
if(TH == 1) begin
local_weight_buffer #(
.PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), .DBG(DBG)
) inst_weight_buff (
.clk(aclk), .rst(~aresetn),
.ivld(axis_dwc_tvalid), .irdy(axis_dwc_tready), .idat(axis_dwc_tdata),
.ovld(axis_lwb_tvalid), .ordy(axis_lwb_tready), .odat(axis_lwb_tdata)
);
end else begin
assign axis_lwb_tvalid = axis_dwc_tvalid;
assign axis_dwc_tready = axis_lwb_tready;
assign axis_lwb_tdata = axis_dwc_tdata;
end

// Reg slice
if(EN_OREG) begin
skid #(
.DATA_WIDTH(WS_BITS_BA), .FEED_STAGES(N_DCPL_STGS)
) inst_oreg (
.clk(aclk), .rst(~aresetn),
.clk(aclk), .rst(!aresetn),
.ivld(axis_lwb_tvalid), .irdy(axis_lwb_tready), .idat(axis_lwb_tdata),
.ovld(m_axis_tvalid), .ordy(m_axis_tready), .odat(m_axis_tdata)
);
Expand Down
Loading
Loading