Skip to content

Commit 850e628

Browse files
Add items() method to nn.Module for state_dict iteration
Add items() method to nn.Module that returns an enumerator of (name, tensor) tuples from the module's state_dict. This enables easy iteration over all parameters and persistent buffers, consistent with the existing items() pattern in ModuleDict and ParameterDict. This addresses the core request in issue #1474 by providing the items() API needed for model merging workflows (averaging parameters between models using state_dict + load_state_dict). Changes: - Add virtual items() method to Module class - Add 'new' keyword to ModuleDict.items() and ParameterDict.items() to properly hide the base class method (different return types) - Add tests for items() on simple and nested modules - Add test demonstrating the model merge pattern from the issue Closes #1474 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5782540 commit 850e628

File tree

4 files changed

+85
-2
lines changed

4 files changed

+85
-2
lines changed

src/TorchSharp/NN/Module.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,20 @@ public virtual void zero_grad(bool set_to_none = true)
510510
/// </summary>
511511
public virtual IEnumerable<Module> children() => named_children().Select(np => np.module);
512512

513+
/// <summary>
514+
/// Return an enumeration of the module's state_dict key/value pairs.
515+
///
516+
/// This is equivalent to calling state_dict() and iterating over its entries.
517+
/// Both parameters and persistent buffers are included.
518+
/// </summary>
519+
/// <returns>An enumerator of (name, tensor) tuples</returns>
520+
public virtual IEnumerator<(string name, Tensor value)> items()
521+
{
522+
foreach (var kv in state_dict()) {
523+
yield return (kv.Key, kv.Value);
524+
}
525+
}
526+
513527
/// <summary>
514528
/// Returns a dictionary containing a whole state of the module.
515529
///

src/TorchSharp/NN/ModuleDict.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public void clear()
4040
/// Return an enumeration of the ParameterDict key/value pairs.
4141
/// </summary>
4242
/// <returns></returns>
43-
public IEnumerator<(string, T)> items() => _list.GetEnumerator();
43+
public new IEnumerator<(string, T)> items() => _list.GetEnumerator();
4444

4545
/// <summary>
4646
/// Return the ParameterDict keys.

src/TorchSharp/NN/ParameterDict.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public void clear()
4040
/// Return an enumeration of the ParameterDict key/value pairs.
4141
/// </summary>
4242
/// <returns></returns>
43-
public IEnumerator<(string, Parameter)> items() => _list.GetEnumerator();
43+
public new IEnumerator<(string, Parameter)> items() => _list.GetEnumerator();
4444

4545
/// <summary>
4646
/// Return the ParameterDict keys.

test/TorchSharpTest/NN.cs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3309,6 +3309,75 @@ public void TestCustomComponentName()
33093309
Assert.True(sd.ContainsKey("_linear2.weight"));
33103310
}
33113311

3312+
[Fact]
3313+
public void TestModuleItems()
3314+
{
3315+
var lin = Linear(10, 5, true);
3316+
var sd = lin.state_dict();
3317+
var items = new List<(string, Tensor)>();
3318+
3319+
using (var enumerator = lin.items()) {
3320+
while (enumerator.MoveNext()) {
3321+
items.Add(enumerator.Current);
3322+
}
3323+
}
3324+
3325+
// items() should return the same entries as state_dict()
3326+
Assert.Equal(sd.Count, items.Count);
3327+
foreach (var (name, value) in items) {
3328+
Assert.True(sd.ContainsKey(name));
3329+
Assert.Equal(sd[name].shape, value.shape);
3330+
}
3331+
}
3332+
3333+
[Fact]
3334+
public void TestModuleItemsWithSubmodules()
3335+
{
3336+
var seq = Sequential(
3337+
("lin1", Linear(10, 5)),
3338+
("lin2", Linear(5, 2)));
3339+
var sd = seq.state_dict();
3340+
var items = new List<(string, Tensor)>();
3341+
3342+
using (var enumerator = seq.items()) {
3343+
while (enumerator.MoveNext()) {
3344+
items.Add(enumerator.Current);
3345+
}
3346+
}
3347+
3348+
Assert.Equal(sd.Count, items.Count);
3349+
Assert.Contains(items, i => i.Item1 == "lin1.weight");
3350+
Assert.Contains(items, i => i.Item1 == "lin2.weight");
3351+
}
3352+
3353+
[Fact]
3354+
public void TestModelMergeUsingItemsAndStateDict()
3355+
{
3356+
// Demonstrate model merging pattern using items() + state_dict() + load_state_dict()
3357+
// This is how users would merge models, matching the PyTorch pattern
3358+
var model1 = Linear(10, 5, true);
3359+
var model2 = Linear(10, 5, true);
3360+
3361+
var sd1 = model1.state_dict();
3362+
var sd2 = model2.state_dict();
3363+
3364+
var merged = new Dictionary<string, Tensor>();
3365+
using (var enumerator = model1.items()) {
3366+
while (enumerator.MoveNext()) {
3367+
var (name, _) = enumerator.Current;
3368+
merged[name] = (sd1[name] + sd2[name]) / 2;
3369+
}
3370+
}
3371+
3372+
model1.load_state_dict(merged);
3373+
3374+
// Verify the merged parameters are the average
3375+
var finalSd = model1.state_dict();
3376+
foreach (var key in merged.Keys) {
3377+
Assert.True(finalSd[key].allclose(merged[key]));
3378+
}
3379+
}
3380+
33123381
private class TestModule3 : Module<Tensor, Tensor>
33133382
{
33143383
public TestModule3() : base(nameof(TestModule3)) { RegisterComponents(); }

0 commit comments

Comments
 (0)