88#include < vector>
99
1010USimpleTorchModule::USimpleTorchModule ()
11+ : ModelId(INDEX_NONE )
1112{
1213 Buffer = NULL ;
1314 BufferDims = NULL ;
@@ -54,7 +55,8 @@ bool USimpleTorchModule::LoadTorchScriptModel(FString FileName)
5455 }
5556 BufferDims = new int [16 ];
5657
57- return Module.FuncTSW_LoadScriptModel (TCHAR_TO_ANSI (*FileName));
58+ ModelId = Module.FuncTSW_LoadScriptModel (TCHAR_TO_ANSI (*FileName));
59+ bResult = (ModelId != INDEX_NONE );
5860 }
5961 else
6062 {
@@ -71,7 +73,7 @@ bool USimpleTorchModule::IsTorchModelLoaded() const
7173 bool bResult = false ;
7274 if (Module.bDllLoaded )
7375 {
74- return Module.FuncTSW_CheckModel ();
76+ return Module.FuncTSW_CheckModel (ModelId );
7577 }
7678
7779 return bResult;
@@ -89,7 +91,7 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
8991 bool bResult = false ;
9092 if (Module.bDllLoaded && Buffer != NULL && OutData.IsDataOwner ())
9193 {
92- TArray<int > InDims = InData.GetDimensions (). Array () ;
94+ TArray<int > InDims = InData.GetDimensions ();
9395
9496 float * pOutData = OutData.IsValid ()
9597 ? OutData.GetRawData ()
@@ -98,21 +100,21 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
98100 int OutDimsCount = 0 ;
99101 if (MethodName == TEXT (" forward" ))
100102 {
101- Module.FuncTSW_ForwardPass_Def (InData.GetRawData (), InDims.GetData (), InDims.Num (),
103+ Module.FuncTSW_ForwardPass_Def (ModelId, InData.GetRawData (), InDims.GetData (), InDims.Num (),
102104 pOutData, BufferDims, &OutDimsCount);
103105 }
104106 else
105107 {
106- Module.FuncTSW_Execute_Def (TCHAR_TO_ANSI (*MethodName), InData.GetRawData (), InDims.GetData (), InDims.Num (),
108+ Module.FuncTSW_Execute_Def (ModelId, TCHAR_TO_ANSI (*MethodName), InData.GetRawData (), InDims.GetData (), InDims.Num (),
107109 pOutData, BufferDims, &OutDimsCount);
108110 }
109111
110112 bResult = (OutDimsCount > 0 );
111113 if (bResult)
112114 {
113- TArray<int32> OldOutDims = OutData.GetDimensions (). Array () ;
115+ TArray<int32> OldOutDims = OutData.GetDimensions ();
114116 bool bOutTensorMatches = (OutDimsCount == OutData.GetDimensions ().Num ());
115- TSet <int32> NewOutDims;
117+ TArray <int32> NewOutDims;
116118
117119 int32 Length = 1 ;
118120 for (int i = 0 ; i < OutDimsCount; i++)
@@ -195,14 +197,14 @@ void FSimpleTorchTensor::Cleanup()
195197 }
196198}
197199
198- int32 FSimpleTorchTensor::GetAddress (TSet <int32> Address) const
200+ int32 FSimpleTorchTensor::GetAddress (TArray <int32> Address) const
199201{
200202 if (!Data)
201203 {
202204 return INDEX_NONE ;
203205 }
204206
205- TArray<int32> AddrArray = Address. Array () ;
207+ TArray<int32> AddrArray = Address;
206208 int32 Addr = 0 ;
207209 for (int32 i = 0 ; i < AddressMultipliersCache.Num (); i++)
208210 {
@@ -218,7 +220,7 @@ int32 FSimpleTorchTensor::GetAddress(TSet<int32> Address) const
218220 return Addr;
219221}
220222
221- bool FSimpleTorchTensor::Create (TSet <int32> TensorDimensions)
223+ bool FSimpleTorchTensor::Create (TArray <int32> TensorDimensions)
222224{
223225 if (Data)
224226 {
@@ -233,7 +235,7 @@ bool FSimpleTorchTensor::Create(TSet<int32> TensorDimensions)
233235
234236 if (DataSize == 0 ) return false ;
235237
236- Dimensions = TensorDimensions. Array () ;
238+ Dimensions = TensorDimensions;
237239 InitAddressSpace ();
238240
239241 Data = new float [DataSize];
@@ -242,7 +244,7 @@ bool FSimpleTorchTensor::Create(TSet<int32> TensorDimensions)
242244 return true ;
243245}
244246
245- bool FSimpleTorchTensor::CreateAsChild (FSimpleTorchTensor* Parent, TSet <int32> Address)
247+ bool FSimpleTorchTensor::CreateAsChild (FSimpleTorchTensor* Parent, TArray <int32> Address)
246248{
247249 bDataOwner = false ;
248250
@@ -296,9 +298,9 @@ bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> A
296298 return true ;
297299}
298300
299- TSet <int32> FSimpleTorchTensor::GetDimensions () const
301+ TArray <int32> FSimpleTorchTensor::GetDimensions () const
300302{
301- TSet <int32> t;
303+ TArray <int32> t;
302304 for (const auto & d : Dimensions)
303305 t.Add (d);
304306
@@ -329,13 +331,13 @@ float* FSimpleTorchTensor::GetRawData(int32* Size) const
329331 }
330332}
331333
332- float * FSimpleTorchTensor::GetCell (TSet <int32> Address)
334+ float * FSimpleTorchTensor::GetCell (TArray <int32> Address)
333335{
334336 int32 Addr = GetAddress (Address);
335337 return Addr == INDEX_NONE ? NULL : &Data[Addr];
336338}
337339
338- float FSimpleTorchTensor::GetValue (TSet <int32> Address) const
340+ float FSimpleTorchTensor::GetValue (TArray <int32> Address) const
339341{
340342 int32 Addr = GetAddress (Address);
341343 return Addr == INDEX_NONE ? 0 : Data[Addr];
@@ -371,7 +373,7 @@ bool FSimpleTorchTensor::FromArray(const TArray<float>& InData)
371373 return false ;
372374}
373375
374- bool FSimpleTorchTensor::Reshape (TSet <int32> NewShape)
376+ bool FSimpleTorchTensor::Reshape (TArray <int32> NewShape)
375377{
376378 if (NewShape.Num () == 0 )
377379 {
@@ -390,7 +392,7 @@ bool FSimpleTorchTensor::Reshape(TSet<int32> NewShape)
390392
391393 if (NewDataSize == DataSize)
392394 {
393- Dimensions = NewShape. Array () ;
395+ Dimensions = NewShape;
394396 InitAddressSpace ();
395397 }
396398 else
@@ -400,7 +402,7 @@ bool FSimpleTorchTensor::Reshape(TSet<int32> NewShape)
400402 Cleanup ();
401403
402404 DataSize = NewDataSize;
403- Dimensions = NewShape. Array () ;
405+ Dimensions = NewShape;
404406 Data = new float [DataSize];
405407
406408 InitAddressSpace ();
@@ -422,7 +424,7 @@ FSimpleTorchTensor FSimpleTorchTensor::Detach()
422424 return FSimpleTorchTensor ();
423425 }
424426
425- TSet <int32> Dims;
427+ TArray <int32> Dims;
426428 for (const auto & Val : Dimensions) Dims.Add (Val);
427429
428430 FSimpleTorchTensor ret = FSimpleTorchTensor (Dims);
0 commit comments