11// VR IK Body Plugin
2- // (c) Yuri N Kalinin, 2021, ykasczc@gmail.com. All right reserved.
2+ // (c) Yuri N Kalinin, 2021-2022 , ykasczc@gmail.com. All right reserved.
33
44#include " SimpleTorchModule.h"
55#include " SimplePyTorch.h"
@@ -32,7 +32,9 @@ void USimpleTorchModule::BeginDestroy()
3232
3333USimpleTorchModule* USimpleTorchModule::CreateSimpleTorchModule (UObject* InParent)
3434{
35- return NewObject<USimpleTorchModule>(InParent);
35+ return InParent
36+ ? NewObject<USimpleTorchModule>(InParent)
37+ : NewObject<USimpleTorchModule>();
3638}
3739
3840bool USimpleTorchModule::LoadTorchScriptModel (FString FileName)
@@ -71,9 +73,9 @@ bool USimpleTorchModule::IsTorchModelLoaded() const
7173 FSimplePyTorchModule& Module = FModuleManager::GetModuleChecked<FSimplePyTorchModule>(TEXT (" SimplePyTorch" ));
7274
7375 bool bResult = false ;
74- if (Module.bDllLoaded )
76+ if (Module.bDllLoaded && Module. FuncTSW_CheckModel )
7577 {
76- return Module.FuncTSW_CheckModel (ModelId);
78+ bResult = Module.FuncTSW_CheckModel (ModelId);
7779 }
7880
7981 return bResult;
@@ -91,7 +93,7 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
9193 bool bResult = false ;
9294 if (Module.bDllLoaded && Buffer != NULL && OutData.IsDataOwner ())
9395 {
94- TArray<int > InDims = InData.GetDimensions ();
96+ TArray<int > InDims = InData.GetDimensions (). Array () ;
9597
9698 float * pOutData = OutData.IsValid ()
9799 ? OutData.GetRawData ()
@@ -100,21 +102,25 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
100102 int OutDimsCount = 0 ;
101103 if (MethodName == TEXT (" forward" ))
102104 {
105+ if (Module.FuncTSW_ForwardPass_Def == NULL ) return false ;
106+
103107 Module.FuncTSW_ForwardPass_Def (ModelId, InData.GetRawData (), InDims.GetData (), InDims.Num (),
104108 pOutData, BufferDims, &OutDimsCount);
105109 }
106110 else
107111 {
112+ if (Module.FuncTSW_Execute_Def == NULL ) return false ;
113+
108114 Module.FuncTSW_Execute_Def (ModelId, TCHAR_TO_ANSI (*MethodName), InData.GetRawData (), InDims.GetData (), InDims.Num (),
109115 pOutData, BufferDims, &OutDimsCount);
110116 }
111117
112- bResult = (OutDimsCount > 0 );
118+ bResult = (OutDimsCount > 0 ) && pOutData != NULL && BufferDims != NULL ;
113119 if (bResult)
114120 {
115- TArray<int32> OldOutDims = OutData.GetDimensions ();
121+ TArray<int32> OldOutDims = OutData.GetDimensions (). Array () ;
116122 bool bOutTensorMatches = (OutDimsCount == OutData.GetDimensions ().Num ());
117- TArray <int32> NewOutDims;
123+ TSet <int32> NewOutDims;
118124
119125 int32 Length = 1 ;
120126 for (int i = 0 ; i < OutDimsCount; i++)
@@ -197,14 +203,14 @@ void FSimpleTorchTensor::Cleanup()
197203 }
198204}
199205
200- int32 FSimpleTorchTensor::GetAddress (TArray <int32> Address) const
206+ int32 FSimpleTorchTensor::GetAddress (TSet <int32> Address) const
201207{
202208 if (!Data)
203209 {
204210 return INDEX_NONE ;
205211 }
206212
207- TArray<int32> AddrArray = Address;
213+ TArray<int32> AddrArray = Address. Array () ;
208214 int32 Addr = 0 ;
209215 for (int32 i = 0 ; i < AddressMultipliersCache.Num (); i++)
210216 {
@@ -220,7 +226,7 @@ int32 FSimpleTorchTensor::GetAddress(TArray<int32> Address) const
220226 return Addr;
221227}
222228
223- bool FSimpleTorchTensor::Create (TArray <int32> TensorDimensions)
229+ bool FSimpleTorchTensor::Create (TSet <int32> TensorDimensions)
224230{
225231 if (Data)
226232 {
@@ -235,7 +241,7 @@ bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)
235241
236242 if (DataSize == 0 ) return false ;
237243
238- Dimensions = TensorDimensions;
244+ Dimensions = TensorDimensions. Array () ;
239245 InitAddressSpace ();
240246
241247 Data = new float [DataSize];
@@ -244,7 +250,7 @@ bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)
244250 return true ;
245251}
246252
247- bool FSimpleTorchTensor::CreateAsChild (FSimpleTorchTensor* Parent, TArray <int32> Address)
253+ bool FSimpleTorchTensor::CreateAsChild (FSimpleTorchTensor* Parent, TSet <int32> Address)
248254{
249255 bDataOwner = false ;
250256
@@ -298,9 +304,9 @@ bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32>
298304 return true ;
299305}
300306
301- TArray <int32> FSimpleTorchTensor::GetDimensions () const
307+ TSet <int32> FSimpleTorchTensor::GetDimensions () const
302308{
303- TArray <int32> t;
309+ TSet <int32> t;
304310 for (const auto & d : Dimensions)
305311 t.Add (d);
306312
@@ -331,13 +337,13 @@ float* FSimpleTorchTensor::GetRawData(int32* Size) const
331337 }
332338}
333339
334- float * FSimpleTorchTensor::GetCell (TArray <int32> Address)
340+ float * FSimpleTorchTensor::GetCell (TSet <int32> Address)
335341{
336342 int32 Addr = GetAddress (Address);
337343 return Addr == INDEX_NONE ? NULL : &Data[Addr];
338344}
339345
340- float FSimpleTorchTensor::GetValue (TArray <int32> Address) const
346+ float FSimpleTorchTensor::GetValue (TSet <int32> Address) const
341347{
342348 int32 Addr = GetAddress (Address);
343349 return Addr == INDEX_NONE ? 0 : Data[Addr];
@@ -373,7 +379,7 @@ bool FSimpleTorchTensor::FromArray(const TArray<float>& InData)
373379 return false ;
374380}
375381
376- bool FSimpleTorchTensor::Reshape (TArray <int32> NewShape)
382+ bool FSimpleTorchTensor::Reshape (TSet <int32> NewShape)
377383{
378384 if (NewShape.Num () == 0 )
379385 {
@@ -392,7 +398,7 @@ bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)
392398
393399 if (NewDataSize == DataSize)
394400 {
395- Dimensions = NewShape;
401+ Dimensions = NewShape. Array () ;
396402 InitAddressSpace ();
397403 }
398404 else
@@ -402,7 +408,7 @@ bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)
402408 Cleanup ();
403409
404410 DataSize = NewDataSize;
405- Dimensions = NewShape;
411+ Dimensions = NewShape. Array () ;
406412 Data = new float [DataSize];
407413
408414 InitAddressSpace ();
@@ -424,7 +430,7 @@ FSimpleTorchTensor FSimpleTorchTensor::Detach()
424430 return FSimpleTorchTensor ();
425431 }
426432
427- TArray <int32> Dims;
433+ TSet <int32> Dims;
428434 for (const auto & Val : Dimensions) Dims.Add (Val);
429435
430436 FSimpleTorchTensor ret = FSimpleTorchTensor (Dims);
0 commit comments