[bugfix] fix gdn sharded_state_dict lora #23

Jintao-Huang wants to merge 2 commits into modelscope:main from
Conversation
Code Review
This pull request implements the sharded_state_dict method in the GatedDeltaNet module to support distributed checkpointing. The implementation shards the module's parameters and submodules, with dedicated handling of the conv1d and in_proj layers for tensor-parallel sharding and tensor splitting. One critical issue was identified: the assertions reference an undefined attribute, self.conv_dim_local_tp, which should be replaced with a locally computed variable.
```python
conv_layer_name_list = ['conv1d.weight']
assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == self.conv_dim_local_tp), (
    self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
if self.conv_bias:
    conv_layer_name_list.append('conv1d.bias')
    assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == self.conv_dim_local_tp), (
        self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])
```
The attribute self.conv_dim_local_tp is used in assertions but is not defined within the class or the method. It should be defined as a local variable, similar to in_proj_dim_local_tp on line 364, representing the local dimension of the convolution layers after tensor parallel sharding. Based on the forward pass logic, this dimension is (2 * self.qk_dim + self.v_dim) // self.tp_size.
Suggested change:

```diff
- conv_layer_name_list = ['conv1d.weight']
- assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == self.conv_dim_local_tp), (
-     self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
- if self.conv_bias:
-     conv_layer_name_list.append('conv1d.bias')
-     assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == self.conv_dim_local_tp), (
-         self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])
+ conv_dim_local_tp = (2 * self.qk_dim + self.v_dim) // self.tp_size
+ conv_layer_name_list = ['conv1d.weight']
+ assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == conv_dim_local_tp), (
+     conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
+ if self.conv_bias:
+     conv_layer_name_list.append('conv1d.bias')
+     assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == conv_dim_local_tp), (
+         conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])
```
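The dimension math behind the suggested fix can be sketched standalone. This is not the actual GatedDeltaNet class; qk_dim, v_dim, and tp_size are illustrative values chosen here, not taken from the PR:

```python
# Example values (assumptions for illustration only).
qk_dim, v_dim, tp_size = 128, 256, 4

# The conv1d channel dimension covers q and k (each qk_dim) plus v (v_dim),
# matching the (2 * self.qk_dim + self.v_dim) expression in the suggestion.
conv_dim = 2 * qk_dim + v_dim

# Tensor-parallel sharding splits that dimension evenly across ranks,
# so each rank's local conv1d weight has conv_dim_local_tp channels.
assert conv_dim % tp_size == 0, 'conv dim must divide evenly across TP ranks'
conv_dim_local_tp = conv_dim // tp_size

print(conv_dim_local_tp)  # 512 // 4 = 128
```

This is the local size the review's assertions compare against sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) on each rank.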
No description provided.