From 992686c7ae97013a420fc10b544cbeb361890f4b Mon Sep 17 00:00:00 2001 From: Logan Hallee Date: Fri, 11 Oct 2024 16:29:38 -0400 Subject: [PATCH] out_proj dimension fix --- Diff-Transformer/multihead_diffattn.py | 2 +- Diff-Transformer/multihead_flashdiff_1.py | 2 +- Diff-Transformer/multihead_flashdiff_2.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Diff-Transformer/multihead_diffattn.py b/Diff-Transformer/multihead_diffattn.py index f33bdf134..75c1cee66 100644 --- a/Diff-Transformer/multihead_diffattn.py +++ b/Diff-Transformer/multihead_diffattn.py @@ -52,7 +52,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * 2 * self.head_dim, embed_dim, bias=False) self.lambda_init = lambda_init_fn(depth) self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) diff --git a/Diff-Transformer/multihead_flashdiff_1.py b/Diff-Transformer/multihead_flashdiff_1.py index 0bcdd162c..98107820b 100644 --- a/Diff-Transformer/multihead_flashdiff_1.py +++ b/Diff-Transformer/multihead_flashdiff_1.py @@ -57,7 +57,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * 2 * self.head_dim, embed_dim, bias=False) self.lambda_init = lambda_init_fn(depth) self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1)) diff --git a/Diff-Transformer/multihead_flashdiff_2.py b/Diff-Transformer/multihead_flashdiff_2.py index c4f5afd5f..315c71bf3 100644 --- a/Diff-Transformer/multihead_flashdiff_2.py +++ b/Diff-Transformer/multihead_flashdiff_2.py @@ -56,7 +56,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.out_proj = nn.Linear(self.num_heads * 2 * self.head_dim, embed_dim, bias=False) self.lambda_init = lambda_init_fn(depth) self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))