
[bugfix] fix gdn sharded_state_dict lora #23

Open

Jintao-Huang wants to merge 2 commits into modelscope:main from Jintao-Huang:fix_gdn_sharded_state_dict_lora

Conversation

@Jintao-Huang (Collaborator)

No description provided.


@gemini-code-assist bot left a comment


Code Review

This pull request implements the sharded_state_dict method in the GatedDeltaNet module to support distributed checkpointing. The implementation includes logic for sharding parameters and submodules, specifically handling conv1d and in_proj layers with tensor parallel sharding and tensor splitting. A critical issue was identified regarding the use of an undefined attribute self.conv_dim_local_tp in assertions, which should be replaced with a locally calculated variable.
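To make the dimension arithmetic behind the review concrete, here is a minimal standalone sketch (names like `local_conv_dim`, `qk_dim`, and `v_dim` follow the PR's terminology, but the function itself is illustrative, not the PR's actual code): the `conv1d` in GatedDeltaNet operates over the concatenation of q, k, and v channels, so its channel dimension is `2 * qk_dim + v_dim`, split evenly across tensor-parallel ranks.

```python
def local_conv_dim(qk_dim: int, v_dim: int, tp_size: int) -> int:
    """Per-rank channel count for the causal conv1d under tensor parallelism.

    q and k each contribute qk_dim channels, v contributes v_dim, and the
    total is partitioned evenly across tp_size tensor-parallel ranks.
    """
    conv_dim = 2 * qk_dim + v_dim
    assert conv_dim % tp_size == 0, "conv channels must divide evenly across ranks"
    return conv_dim // tp_size


# Example dimensions (assumed, for illustration only):
# qk_dim=128, v_dim=256, tp_size=4 -> (2*128 + 256) // 4 = 128
print(local_conv_dim(128, 256, 4))
```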

Comment on lines +384 to +390
conv_layer_name_list = ['conv1d.weight']
assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == self.conv_dim_local_tp), (
self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
if self.conv_bias:
conv_layer_name_list.append('conv1d.bias')
assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == self.conv_dim_local_tp), (
self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])


critical

The attribute self.conv_dim_local_tp is used in assertions but is not defined within the class or the method. It should be defined as a local variable, similar to in_proj_dim_local_tp on line 364, representing the local dimension of the convolution layers after tensor parallel sharding. Based on the forward pass logic, this dimension is (2 * self.qk_dim + self.v_dim) // self.tp_size.

Suggested change

```diff
-conv_layer_name_list = ['conv1d.weight']
-assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == self.conv_dim_local_tp), (
-    self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
-if self.conv_bias:
-    conv_layer_name_list.append('conv1d.bias')
-    assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == self.conv_dim_local_tp), (
-        self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])
+conv_dim_local_tp = (2 * self.qk_dim + self.v_dim) // self.tp_size
+conv_layer_name_list = ['conv1d.weight']
+assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == conv_dim_local_tp), (
+    conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
+if self.conv_bias:
+    conv_layer_name_list.append('conv1d.bias')
+    assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == conv_dim_local_tp), (
+        conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])
```
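The corrected assertion pattern can be exercised in isolation with a toy stand-in (everything here is hypothetical: the real code uses Megatron sharded tensors, and the prefix and dimensions below are assumed example values, not values from the PR):

```python
class _Tensor:
    """Toy tensor exposing only the .data.size(dim) accessor the assertion uses."""

    def __init__(self, *shape):
        self._shape = shape
        self.data = self  # mimic tensor.data pointing at the tensor itself

    def size(self, dim):
        return self._shape[dim]


# Assumed example dimensions: qk_dim=128, v_dim=256, tp_size=4.
qk_dim, v_dim, tp_size = 128, 256, 4
# Local variable, as the review suggests -- not an undefined self attribute.
conv_dim_local_tp = (2 * qk_dim + v_dim) // tp_size  # 128 channels per rank

prefix = 'decoder.layers.0.mixer.'  # hypothetical module path
sharded_state_dict = {
    f'{prefix}conv1d.weight': _Tensor(conv_dim_local_tp, 1, 4),
}

# The fixed check: the per-rank conv weight's channel dim must match.
assert sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == conv_dim_local_tp
print('conv1d.weight channel dim ok:', conv_dim_local_tp)
```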
