@@ -40,14 +40,25 @@ public FastTextWrapper()
4040 _fastText = CreateFastText ( ) ;
4141 }
4242
43+ /// <summary>
44+ /// Loads a trained model from a byte array.
45+ /// </summary>
46+ /// <param name="bytes">Bytes array containing the model (.bin file).</param>
47+ public void LoadModel ( byte [ ] bytes )
48+ {
49+ LoadModelData ( _fastText , bytes , bytes . Length ) ;
50+ _maxLabelLen = GetMaxLabelLength ( _fastText ) ;
51+ _modelLoaded = true ;
52+ }
53+
4354 /// <summary>
4455 /// Loads a trained model from a file.
4556 /// </summary>
4657 /// <param name="path">Path to a model (.bin file).</param>
4758 public void LoadModel ( string path )
4859 {
4960 LoadModel ( _fastText , path ) ;
50- _maxLabelLen = GetMaxLabelLenght ( _fastText ) ;
61+ _maxLabelLen = GetMaxLabelLength ( _fastText ) ;
5162 _modelLoaded = true ;
5263 }
5364
@@ -74,6 +85,34 @@ public unsafe string[] GetLabels()
7485 return result ;
7586 }
7687
88+ /// <summary>
89+ /// Calculate nearest neighbors from input text.
90+ /// </summary>
91+ /// <param name="text">Text to calculate nearest neighbors from.</param>
92+ /// <param name="number">Number of neighbors.</param>
93+ /// <returns>Nearest neighbor predictions.</returns>
94+ public unsafe Prediction [ ] GetNN ( string text , int number )
95+ {
96+ CheckModelLoaded ( ) ;
97+
98+ var probs = new float [ number ] ;
99+ IntPtr labelsPtr ;
100+
101+ int cnt = GetNN ( _fastText , _utf8 . GetBytes ( text ) , new IntPtr ( & labelsPtr ) , probs , number ) ;
102+ var result = new Prediction [ cnt ] ;
103+
104+ for ( int i = 0 ; i < cnt ; i ++ )
105+ {
106+ var ptr = Marshal . ReadIntPtr ( labelsPtr , i * IntPtr . Size ) ;
107+ string label = _utf8 . GetString ( GetStringBytes ( ptr ) ) ;
108+ result [ i ] = new Prediction ( probs [ i ] , label ) ;
109+ }
110+
111+ DestroyStrings ( labelsPtr , cnt ) ;
112+
113+ return result ;
114+ }
115+
77116 /// <summary>
78117 /// Predicts a single label from input text.
79118 /// </summary>
@@ -169,7 +208,7 @@ public void Train(string inputPath, string outputPath, SupervisedArgs args)
169208 } ;
170209
171210 TrainSupervised ( _fastText , inputPath , outputPath , argsStruct , args . LabelPrefix ) ;
172- _maxLabelLen = GetMaxLabelLenght ( _fastText ) ;
211+ _maxLabelLen = GetMaxLabelLength ( _fastText ) ;
173212 _modelLoaded = true ;
174213 }
175214
@@ -214,7 +253,7 @@ public void Train(string inputPath, string outputPath, FastTextArgs args)
214253 } ;
215254
216255 Train ( _fastText , inputPath , outputPath , argsStruct , args . LabelPrefix , args . PretrainedVectors ) ;
217- _maxLabelLen = GetMaxLabelLenght ( _fastText ) ;
256+ _maxLabelLen = GetMaxLabelLength ( _fastText ) ;
218257 _modelLoaded = true ;
219258 }
220259
0 commit comments