-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathRotaryEmbedding.cs
More file actions
125 lines (100 loc) · 3.66 KB
/
RotaryEmbedding.cs
File metadata and controls
125 lines (100 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Text.Json.Serialization;
using TorchSharp;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Core;
/// <summary>
/// Rope-scaling section of a model's configuration JSON.
/// Defaults describe the "default" rope type with no frequency scaling applied.
/// </summary>
public class RopeScalingConfig
{
    /// <summary>Overall scaling factor; 1.0 means no scaling.</summary>
    [JsonPropertyName("factor")]
    public float Factor { get; set; } = 1.0f;

    /// <summary>Low-frequency scaling factor; 1.0 means no scaling.</summary>
    [JsonPropertyName("low_freq_factor")]
    public float LowFreqFactor { get; set; } = 1.0f;

    /// <summary>High-frequency scaling factor; 1.0 means no scaling.</summary>
    [JsonPropertyName("high_freq_factor")]
    public float HighFreqFactor { get; set; } = 1.0f;

    /// <summary>Context length the model was originally trained with.</summary>
    [JsonPropertyName("original_max_position_embeddings")]
    public int OriginalMaxPositionEmbeddings { get; set; } = 8192;

    /// <summary>Rope variant name; only "default" is supported by <c>RotaryEmbedding</c>.</summary>
    [JsonPropertyName("rope_type")]
    public string RopeType { get; set; } = "default";
}
/// <summary>
/// Argument bundle consumed by <c>RotaryEmbedding.forward</c>.
/// </summary>
internal class RotaryEmbeddingInput
{
    // Hidden-state tensor; forward only reads its device and dtype.
    public Tensor Input { get; set; }

    // Position ids used to build the rotary angle table.
    public Tensor PositionIds { get; set; }

    // Optional sequence length; not read by RotaryEmbedding.forward.
    public int? SeqLen { get; set; }

    public RotaryEmbeddingInput(Tensor input, Tensor positionIds, int? seqLen = null)
    {
        this.Input = input;
        this.PositionIds = positionIds;
        this.SeqLen = seqLen;
    }
}
/// <summary>
/// Result pair produced by <c>RotaryEmbedding.forward</c>: the cosine and sine
/// tables applied by rotary position embedding.
/// </summary>
internal class RotaryEmbeddingOutput
{
    // Cosine table, cast to the dtype of the forward input.
    public Tensor Cos { get; set; }

    // Sine table, cast to the dtype of the forward input.
    public Tensor Sin { get; set; }

    public RotaryEmbeddingOutput(Tensor cos, Tensor sin)
    {
        this.Cos = cos;
        this.Sin = sin;
    }
}
/// <summary>
/// Rotary position embedding (RoPE) module: precomputes the inverse-frequency
/// table at construction and, in <c>forward</c>, expands it against the given
/// position ids into cosine/sine tables matching the input's dtype.
/// </summary>
internal class RotaryEmbedding : nn.Module<
    RotaryEmbeddingInput,
    RotaryEmbeddingOutput>
{
    private readonly double _base;
    private readonly int _maxPositionEmbeddings;
    private readonly int _dim;

    /// <summary>
    /// Convenience overload: wraps <paramref name="maxPositionEmbeddings"/> in a
    /// default-typed <see cref="RopeScalingConfig"/> and delegates to the main constructor.
    /// </summary>
    public RotaryEmbedding(double baseValue, int maxPositionEmbeddings, int dim)
        : this(baseValue, dim, new RopeScalingConfig() { RopeType = "default", OriginalMaxPositionEmbeddings = maxPositionEmbeddings })
    {
    }

    /// <summary>
    /// Builds the module and registers the non-persistent "inv_freq" buffer.
    /// Only <c>config.RopeType == "default"</c> is supported.
    /// </summary>
    /// <exception cref="NotImplementedException">Thrown for any non-default rope type.</exception>
    public RotaryEmbedding(double baseValue, int dim, RopeScalingConfig config)
        : base(nameof(RotaryEmbedding))
    {
        _base = baseValue;
        _maxPositionEmbeddings = config.OriginalMaxPositionEmbeddings;
        _dim = dim;

        if (config.RopeType != "default")
        {
            throw new NotImplementedException("Rope type not implemented");
        }

        // Even indices 0, 2, ..., dim-2 as float32; inv_freq[i] = base^(-2i/dim),
        // the standard RoPE inverse-frequency table (dim/2 entries).
        var exponents = torch.arange(0, _dim, 2, dtype: ScalarType.Int64).to(torch.float32);
        this.register_buffer("inv_freq", torch.pow(baseValue, -1.0f * (exponents / dim)), persistent: false);
    }

    public int Dim => _dim;

#pragma warning disable MSML_GeneralName // This name should be PascalCased
    public override RotaryEmbeddingOutput forward(RotaryEmbeddingInput input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
    {
        var hiddenStates = input.Input;
        var positionIds = input.PositionIds;
        // NOTE: input.SeqLen is intentionally not consumed here.

        // TODO
        // can be calculated once and cached
        var invFreq = this.get_buffer("inv_freq")!.to(hiddenStates.device);

        // Broadcast inv_freq across the leading axis of positionIds
        // (presumably the batch axis — shapes not verifiable from this file).
        var invFreqExpanded = invFreq
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(new long[] { positionIds.shape[0], -1, 1 });
        var positionIdsExpanded = positionIds.unsqueeze(1).to(torch.float32);

        // Outer product of frequencies and positions, then swap the last two axes.
        var freqs = (invFreqExpanded * positionIdsExpanded).transpose(1, 2);

        // Duplicate the angles along the last axis so cos/sin cover the full dim.
        var emb = torch.cat([freqs, freqs], dim: -1);

        return new RotaryEmbeddingOutput(
            torch.cos(emb).to_type(hiddenStates.dtype),
            torch.sin(emb).to_type(hiddenStates.dtype));
    }
}