Skip to content

Commit ab971d0

Browse files
committed
Merge branch 'NoorDigitalAgency-master'
2 parents 63d5656 + 83b7737 commit ab971d0

5 files changed

Lines changed: 73 additions & 14 deletions

File tree

FastText.NetWrapper/FastText.NetWrapper.csproj

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
</PropertyGroup>
2020

2121
<ItemGroup>
22-
<PackageReference Include="FastText.Native.Linux" Version="1.0.73" />
23-
<PackageReference Include="FastText.Native.MacOs" Version="1.0.74" />
24-
<PackageReference Include="FastText.Native.Windows" Version="1.0.72" />
25-
<PackageReference Include="LibLog" Version="5.0.6" />
26-
<PackageReference Include="NativeLibraryManager" Version="1.0.14" />
22+
<PackageReference Include="FastText.Native.Linux" Version="1.0.84" />
23+
<PackageReference Include="FastText.Native.MacOs" Version="1.0.84" />
24+
<PackageReference Include="FastText.Native.Windows" Version="1.0.84" />
25+
<PackageReference Include="LibLog" Version="5.0.8" />
26+
<PackageReference Include="NativeLibraryManager" Version="1.0.18" />
2727
</ItemGroup>
2828

2929
</Project>

FastText.NetWrapper/FastTextWrapper.Api.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,18 @@ private struct TrainingArgsStruct
7171

7272
[DllImport(FastTextDll)]
7373
private static extern void LoadModel(IntPtr hPtr, string path);
74+
75+
[DllImport(FastTextDll)]
76+
private static extern void LoadModelData(IntPtr hPtr, byte[] data, long length);
7477

7578
[DllImport(FastTextDll)]
76-
private static extern int GetMaxLabelLenght(IntPtr hPtr);
79+
private static extern int GetMaxLabelLength(IntPtr hPtr);
7780

7881
[DllImport(FastTextDll)]
7982
private static extern int GetLabels(IntPtr hPtr, IntPtr labels);
83+
84+
[DllImport(FastTextDll)]
85+
private static extern int GetNN(IntPtr hPtr, byte[] input, IntPtr predictedLabels, float[] predictedProbabilities, int n);
8086

8187
[DllImport(FastTextDll)]
8288
private static extern float PredictSingle(IntPtr hPtr, byte[] input, IntPtr predicted);

FastText.NetWrapper/FastTextWrapper.cs

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,25 @@ public FastTextWrapper()
4040
_fastText = CreateFastText();
4141
}
4242

43+
/// <summary>
44+
/// Loads a trained model from a byte array.
45+
/// </summary>
46+
/// <param name="bytes">Bytes array containing the model (.bin file).</param>
47+
public void LoadModel(byte[] bytes)
48+
{
49+
LoadModelData(_fastText, bytes, bytes.Length);
50+
_maxLabelLen = GetMaxLabelLength(_fastText);
51+
_modelLoaded = true;
52+
}
53+
4354
/// <summary>
4455
/// Loads a trained model from a file.
4556
/// </summary>
4657
/// <param name="path">Path to a model (.bin file).</param>
4758
public void LoadModel(string path)
4859
{
4960
LoadModel(_fastText, path);
50-
_maxLabelLen = GetMaxLabelLenght(_fastText);
61+
_maxLabelLen = GetMaxLabelLength(_fastText);
5162
_modelLoaded = true;
5263
}
5364

@@ -74,6 +85,34 @@ public unsafe string[] GetLabels()
7485
return result;
7586
}
7687

88+
/// <summary>
89+
/// Calculate nearest neighbors from input text.
90+
/// </summary>
91+
/// <param name="text">Text to calculate nearest neighbors from.</param>
92+
/// <param name="number">Number of neighbors.</param>
93+
/// <returns>Nearest neighbor predictions.</returns>
94+
public unsafe Prediction[] GetNN(string text, int number)
95+
{
96+
CheckModelLoaded();
97+
98+
var probs = new float[number];
99+
IntPtr labelsPtr;
100+
101+
int cnt = GetNN(_fastText, _utf8.GetBytes(text), new IntPtr(&labelsPtr), probs, number);
102+
var result = new Prediction[cnt];
103+
104+
for (int i = 0; i < cnt; i++)
105+
{
106+
var ptr = Marshal.ReadIntPtr(labelsPtr, i * IntPtr.Size);
107+
string label = _utf8.GetString(GetStringBytes(ptr));
108+
result[i] = new Prediction(probs[i], label);
109+
}
110+
111+
DestroyStrings(labelsPtr, cnt);
112+
113+
return result;
114+
}
115+
77116
/// <summary>
78117
/// Predicts a single label from input text.
79118
/// </summary>
@@ -169,7 +208,7 @@ public void Train(string inputPath, string outputPath, SupervisedArgs args)
169208
};
170209

171210
TrainSupervised(_fastText, inputPath, outputPath, argsStruct, args.LabelPrefix);
172-
_maxLabelLen = GetMaxLabelLenght(_fastText);
211+
_maxLabelLen = GetMaxLabelLength(_fastText);
173212
_modelLoaded = true;
174213
}
175214

@@ -214,7 +253,7 @@ public void Train(string inputPath, string outputPath, FastTextArgs args)
214253
};
215254

216255
Train(_fastText, inputPath, outputPath, argsStruct, args.LabelPrefix, args.PretrainedVectors);
217-
_maxLabelLen = GetMaxLabelLenght(_fastText);
256+
_maxLabelLen = GetMaxLabelLength(_fastText);
218257
_modelLoaded = true;
219258
}
220259

TestUtil/Program.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ namespace TestUtil
1010
{
1111
class Program
1212
{
13-
private const string Usage = "Usage: tesutil [train|trainlowlevel|load] train_file model_file";
13+
private static string Usage = "Usage: tesutil [train|trainlowlevel|load] train_file model_file\n" +
14+
"Usage: testutil nn model_file";
1415

1516
static void Main(string[] args)
1617
{
17-
if (args.Length < 3)
18+
if ((args.FirstOrDefault() == "nn" && args.Length < 2) || (args.FirstOrDefault() != "nn" && args.Length < 3))
1819
{
1920
Console.WriteLine(Usage);
2021
return;
@@ -34,8 +35,16 @@ static void Main(string[] args)
3435
fastText.LoadModel(args[2]);
3536
break;
3637
}
37-
38-
Test(fastText);
38+
39+
if (args[0] != "nn")
40+
{
41+
Test(fastText);
42+
}
43+
else
44+
{
45+
fastText.LoadModel(File.ReadAllBytes(args[1]));
46+
TestNN(fastText);
47+
}
3948
}
4049
}
4150

@@ -67,5 +76,10 @@ private static void Test(FastTextWrapper fastText)
6776
var predictions = fastText.PredictMultiple("Can I use a larger crockpot than the recipe calls for?", 4);
6877
var vector = fastText.GetSentenceVector("Can I use a larger crockpot than the recipe calls for?");
6978
}
79+
80+
private static void TestNN(FastTextWrapper fastText)
81+
{
82+
var predictions = fastText.GetNN("train", 5);
83+
}
7084
}
7185
}

TestUtil/TestUtil.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
5-
<TargetFramework>netcoreapp2.2</TargetFramework>
5+
<TargetFramework>netcoreapp3.1</TargetFramework>
66
</PropertyGroup>
77

88
<ItemGroup>

0 commit comments

Comments
 (0)