Skip to content

Commit e5bde76

Browse files
authored
Third-party dlls management
- LibTorch DLLs aren't copied to plugin's Binaries directory anymore - torchscript_wrapper.dll udpated to PyTorch Build 1.10.1 - copyrights updated
1 parent db0036c commit e5bde76

6 files changed

Lines changed: 83 additions & 63 deletions

File tree

Source/SimplePyTorch/Private/SimplePyTorch.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// VR IK Body Plugin
2-
// (c) Yuri N Kalinin, 2021, ykasczc@gmail.com. All right reserved.
2+
// (c) Yuri N Kalinin, 2021-2022, ykasczc@gmail.com. All right reserved.
33

44
#include "SimplePyTorch.h"
55
#include "HAL/PlatformProcess.h"
@@ -14,27 +14,32 @@ void FSimplePyTorchModule::StartupModule()
1414
FString FilePath;
1515
const FString szBinaries = TEXT("Binaries");
1616
const FString szPlatform = TEXT("Win64");
17-
17+
1818
#if WITH_EDITOR
1919
auto ThisPlugin = IPluginManager::Get().FindPlugin(TEXT("SimplePyTorch"));
2020
if (ThisPlugin.IsValid())
2121
{
2222
FilePath = FPaths::ConvertRelativePathToFull(ThisPlugin->GetBaseDir());
23-
24-
FString PluginBinariesDir = FilePath / TEXT("Source/ThirdParty/pytorch") / szBinaries / szPlatform;
25-
UE_LOG(LogTemp, Log, TEXT("PyTorch third-party dlls directory: %s"), *PluginBinariesDir);
26-
FPlatformProcess::PushDllDirectory(*PluginBinariesDir);
23+
FilePath = FilePath / TEXT("Source/ThirdParty/pytorch") / szBinaries / szPlatform;
24+
}
25+
else
26+
{
27+
FilePath = FPaths::ProjectDir() / TEXT("Binaries/ThirdParty/PyTorch");
2728
}
2829
#else
29-
FilePath = FPaths::ConvertRelativePathToFull(FPaths::ProjectDir());
30-
#endif
31-
FilePath = FilePath / szBinaries / szPlatform / TEXT("torchscript_wrapper.dll");
30+
FilePath = FPaths::ProjectDir() / TEXT("Binaries/ThirdParty/PyTorch");
31+
#endif
32+
FPlatformProcess::PushDllDirectory(*FilePath);
33+
FilePath = FilePath / TEXT("torchscript_wrapper.dll");
34+
3235
WrapperDllHandle = NULL;
3336
bDllLoaded = false;
34-
37+
3538
#if PLATFORM_WINDOWS
3639
if (FPaths::FileExists(FilePath))
3740
{
41+
UE_LOG(LogTemp, Log, TEXT("SimplePyTorchModule: Loading torch wrapper from %s"), *FilePath);
42+
3843
WrapperDllHandle = FPlatformProcess::GetDllHandle(*FilePath);
3944

4045
if (WrapperDllHandle != NULL)
@@ -91,6 +96,14 @@ void FSimplePyTorchModule::ShutdownModule()
9196
if (WrapperDllHandle != NULL)
9297
{
9398
FPlatformProcess::FreeDllHandle(WrapperDllHandle);
99+
WrapperDllHandle = NULL;
100+
bDllLoaded = false;
101+
FuncTSW_LoadScriptModel = NULL;
102+
FuncTSW_CheckModel = NULL;
103+
FuncTSW_Forward1d = NULL;
104+
FuncTSW_ForwardTensor = NULL;
105+
FuncTSW_ForwardPass_Def = NULL;
106+
FuncTSW_Execute_Def = NULL;
94107
}
95108
}
96109

Source/SimplePyTorch/Private/SimpleTorchModule.cpp

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// VR IK Body Plugin
2-
// (c) Yuri N Kalinin, 2021, ykasczc@gmail.com. All right reserved.
2+
// (c) Yuri N Kalinin, 2021-2022, ykasczc@gmail.com. All right reserved.
33

44
#include "SimpleTorchModule.h"
55
#include "SimplePyTorch.h"
@@ -32,7 +32,9 @@ void USimpleTorchModule::BeginDestroy()
3232

3333
USimpleTorchModule* USimpleTorchModule::CreateSimpleTorchModule(UObject* InParent)
3434
{
35-
return NewObject<USimpleTorchModule>(InParent);
35+
return InParent
36+
? NewObject<USimpleTorchModule>(InParent)
37+
: NewObject<USimpleTorchModule>();
3638
}
3739

3840
bool USimpleTorchModule::LoadTorchScriptModel(FString FileName)
@@ -71,9 +73,9 @@ bool USimpleTorchModule::IsTorchModelLoaded() const
7173
FSimplePyTorchModule& Module = FModuleManager::GetModuleChecked<FSimplePyTorchModule>(TEXT("SimplePyTorch"));
7274

7375
bool bResult = false;
74-
if (Module.bDllLoaded)
76+
if (Module.bDllLoaded && Module.FuncTSW_CheckModel)
7577
{
76-
return Module.FuncTSW_CheckModel(ModelId);
78+
bResult = Module.FuncTSW_CheckModel(ModelId);
7779
}
7880

7981
return bResult;
@@ -91,7 +93,7 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
9193
bool bResult = false;
9294
if (Module.bDllLoaded && Buffer != NULL && OutData.IsDataOwner())
9395
{
94-
TArray<int> InDims = InData.GetDimensions();
96+
TArray<int> InDims = InData.GetDimensions().Array();
9597

9698
float* pOutData = OutData.IsValid()
9799
? OutData.GetRawData()
@@ -100,21 +102,25 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
100102
int OutDimsCount = 0;
101103
if (MethodName == TEXT("forward"))
102104
{
105+
if (Module.FuncTSW_ForwardPass_Def == NULL) return false;
106+
103107
Module.FuncTSW_ForwardPass_Def(ModelId, InData.GetRawData(), InDims.GetData(), InDims.Num(),
104108
pOutData, BufferDims, &OutDimsCount);
105109
}
106110
else
107111
{
112+
if (Module.FuncTSW_Execute_Def == NULL) return false;
113+
108114
Module.FuncTSW_Execute_Def(ModelId, TCHAR_TO_ANSI(*MethodName), InData.GetRawData(), InDims.GetData(), InDims.Num(),
109115
pOutData, BufferDims, &OutDimsCount);
110116
}
111117

112-
bResult = (OutDimsCount > 0);
118+
bResult = (OutDimsCount > 0) && pOutData != NULL && BufferDims != NULL;
113119
if (bResult)
114120
{
115-
TArray<int32> OldOutDims = OutData.GetDimensions();
121+
TArray<int32> OldOutDims = OutData.GetDimensions().Array();
116122
bool bOutTensorMatches = (OutDimsCount == OutData.GetDimensions().Num());
117-
TArray<int32> NewOutDims;
123+
TSet<int32> NewOutDims;
118124

119125
int32 Length = 1;
120126
for (int i = 0; i < OutDimsCount; i++)
@@ -197,14 +203,14 @@ void FSimpleTorchTensor::Cleanup()
197203
}
198204
}
199205

200-
int32 FSimpleTorchTensor::GetAddress(TArray<int32> Address) const
206+
int32 FSimpleTorchTensor::GetAddress(TSet<int32> Address) const
201207
{
202208
if (!Data)
203209
{
204210
return INDEX_NONE;
205211
}
206212

207-
TArray<int32> AddrArray = Address;
213+
TArray<int32> AddrArray = Address.Array();
208214
int32 Addr = 0;
209215
for (int32 i = 0; i < AddressMultipliersCache.Num(); i++)
210216
{
@@ -220,7 +226,7 @@ int32 FSimpleTorchTensor::GetAddress(TArray<int32> Address) const
220226
return Addr;
221227
}
222228

223-
bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)
229+
bool FSimpleTorchTensor::Create(TSet<int32> TensorDimensions)
224230
{
225231
if (Data)
226232
{
@@ -235,7 +241,7 @@ bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)
235241

236242
if (DataSize == 0) return false;
237243

238-
Dimensions = TensorDimensions;
244+
Dimensions = TensorDimensions.Array();
239245
InitAddressSpace();
240246

241247
Data = new float[DataSize];
@@ -244,7 +250,7 @@ bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)
244250
return true;
245251
}
246252

247-
bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> Address)
253+
bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> Address)
248254
{
249255
bDataOwner = false;
250256

@@ -298,9 +304,9 @@ bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32>
298304
return true;
299305
}
300306

301-
TArray<int32> FSimpleTorchTensor::GetDimensions() const
307+
TSet<int32> FSimpleTorchTensor::GetDimensions() const
302308
{
303-
TArray<int32> t;
309+
TSet<int32> t;
304310
for (const auto& d : Dimensions)
305311
t.Add(d);
306312

@@ -331,13 +337,13 @@ float* FSimpleTorchTensor::GetRawData(int32* Size) const
331337
}
332338
}
333339

334-
float* FSimpleTorchTensor::GetCell(TArray<int32> Address)
340+
float* FSimpleTorchTensor::GetCell(TSet<int32> Address)
335341
{
336342
int32 Addr = GetAddress(Address);
337343
return Addr == INDEX_NONE ? NULL : &Data[Addr];
338344
}
339345

340-
float FSimpleTorchTensor::GetValue(TArray<int32> Address) const
346+
float FSimpleTorchTensor::GetValue(TSet<int32> Address) const
341347
{
342348
int32 Addr = GetAddress(Address);
343349
return Addr == INDEX_NONE ? 0 : Data[Addr];
@@ -373,7 +379,7 @@ bool FSimpleTorchTensor::FromArray(const TArray<float>& InData)
373379
return false;
374380
}
375381

376-
bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)
382+
bool FSimpleTorchTensor::Reshape(TSet<int32> NewShape)
377383
{
378384
if (NewShape.Num() == 0)
379385
{
@@ -392,7 +398,7 @@ bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)
392398

393399
if (NewDataSize == DataSize)
394400
{
395-
Dimensions = NewShape;
401+
Dimensions = NewShape.Array();
396402
InitAddressSpace();
397403
}
398404
else
@@ -402,7 +408,7 @@ bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)
402408
Cleanup();
403409

404410
DataSize = NewDataSize;
405-
Dimensions = NewShape;
411+
Dimensions = NewShape.Array();
406412
Data = new float[DataSize];
407413

408414
InitAddressSpace();
@@ -424,7 +430,7 @@ FSimpleTorchTensor FSimpleTorchTensor::Detach()
424430
return FSimpleTorchTensor();
425431
}
426432

427-
TArray<int32> Dims;
433+
TSet<int32> Dims;
428434
for (const auto& Val : Dimensions) Dims.Add(Val);
429435

430436
FSimpleTorchTensor ret = FSimpleTorchTensor(Dims);

Source/SimplePyTorch/Public/SimplePyTorch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// VR IK Body Plugin
2-
// (c) Yuri N Kalinin, 2021, ykasczc@gmail.com. All right reserved.
2+
// (c) Yuri N Kalinin, 2021-2022, ykasczc@gmail.com. All right reserved.
33

44
#pragma once
55

Source/SimplePyTorch/Public/SimpleTorchModule.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// VR IK Body Plugin
2-
// (c) Yuri N Kalinin, 2021, ykasczc@gmail.com. All right reserved.
2+
// (c) Yuri N Kalinin, 2021-2022, ykasczc@gmail.com. All right reserved.
33

44
#pragma once
55

@@ -53,7 +53,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
5353
void InitAddressSpace();
5454

5555
/** Get flat address in Data from multidimensional address */
56-
int32 GetAddress(TArray<int32> Address) const;
56+
int32 GetAddress(TSet<int32> Address) const;
5757
public:
5858

5959
FSimpleTorchTensor()
@@ -62,15 +62,15 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
6262
, DataSize(0)
6363
, ParentTensor(nullptr)
6464
{}
65-
FSimpleTorchTensor(FSimpleTorchTensor* Parent, TArray<int32> SubAddress)
65+
FSimpleTorchTensor(FSimpleTorchTensor* Parent, TSet<int32> SubAddress)
6666
: Data(NULL)
6767
, bDataOwner(true)
6868
, DataSize(0)
6969
, ParentTensor(nullptr)
7070
{
7171
CreateAsChild(Parent, SubAddress);
7272
}
73-
FSimpleTorchTensor(TArray<int32> Dimensions)
73+
FSimpleTorchTensor(TSet<int32> Dimensions)
7474
: Data(NULL)
7575
, bDataOwner(true)
7676
, DataSize(0)
@@ -87,10 +87,10 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
8787
void Cleanup();
8888

8989
// Set dimensions and allocate memory
90-
bool Create(TArray<int32> TensorDimensions);
90+
bool Create(TSet<int32> TensorDimensions);
9191

9292
// Create tensor as a subtensor in another tensor (share the same memory)
93-
bool CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> Address);
93+
bool CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> Address);
9494

9595
// Is tensor initialized?
9696
bool IsValid() const { return Data != NULL; }
@@ -99,7 +99,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
9999
bool IsDataOwner() const { return bDataOwner; }
100100

101101
// Get current tensor dimensions
102-
TArray<int32> GetDimensions() const;
102+
TSet<int32> GetDimensions() const;
103103

104104
// Get number of itemes in flat array
105105
int32 GetDataSize() const { return DataSize; }
@@ -112,11 +112,11 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
112112
float* GetRawData() const { return Data; }
113113

114114
// Convert multidimensional address to flat address
115-
int32 GetRawAddress(TArray<int32> Address) const { return GetAddress(Address); }
115+
int32 GetRawAddress(TSet<int32> Address) const { return GetAddress(Address); }
116116

117117
// Get reference to single float value with address
118-
float* GetCell(TArray<int32> Address);
119-
float GetValue(TArray<int32> Address) const;
118+
float* GetCell(TSet<int32> Address);
119+
float GetValue(TSet<int32> Address) const;
120120

121121
/* Create float array. Only works for tensor with one dimension.
122122
* Ex: auto p = FSimpleTorchTensor({ 4, 12 });
@@ -139,7 +139,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
139139

140140
// Change dimensions.
141141
// Only keeps data if new overall size is equal to old sizse
142-
bool Reshape(TArray<int32> NewShape);
142+
bool Reshape(TSet<int32> NewShape);
143143

144144
// Create copy of this tensor
145145
FSimpleTorchTensor Detach();

Source/SimplePyTorch/SimplePyTorch.Build.cs

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// VR IK Body Plugin
2-
// (c) Yuri N Kalinin, 2021, ykasczc@gmail.com. All right reserved.
2+
// (c) Yuri N Kalinin, 2021-2022, ykasczc@gmail.com. All right reserved.
33

44
using UnrealBuildTool;
55
using System.IO;
@@ -60,27 +60,28 @@ public SimplePyTorch(ReadOnlyTargetRules Target) : base(Target)
6060

6161
if (Target.Platform == UnrealTargetPlatform.Win64)
6262
{
63-
// my PyTorch wrapper
64-
RuntimeDependencies.Add("$(BinaryOutputDir)/torchscript_wrapper.dll", Path.Combine(TorchBinariesPath, "torchscript_wrapper.dll"));
65-
if (!Target.bBuildEditor)
63+
// LibTorch libraries
64+
string[] DLLs = new string[]
6665
{
67-
// Copy DLLs to target packaged project
68-
RuntimeDependencies.Add("$(BinaryOutputDir)/asmjit.dll", Path.Combine(TorchBinariesPath, "asmjit.dll"));
69-
RuntimeDependencies.Add("$(BinaryOutputDir)/c10.dll", Path.Combine(TorchBinariesPath, "c10.dll"));
70-
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_detectron_ops.dll", Path.Combine(TorchBinariesPath, "caffe2_detectron_ops.dll"));
71-
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_module_test_dynamic.dll", Path.Combine(TorchBinariesPath, "caffe2_module_test_dynamic.dll"));
72-
RuntimeDependencies.Add("$(BinaryOutputDir)/fbgemm.dll", Path.Combine(TorchBinariesPath, "fbgemm.dll"));
73-
RuntimeDependencies.Add("$(BinaryOutputDir)/fbjni.dll", Path.Combine(TorchBinariesPath, "fbjni.dll"));
74-
RuntimeDependencies.Add("$(BinaryOutputDir)/libiomp5md.dll", Path.Combine(TorchBinariesPath, "libiomp5md.dll"));
75-
RuntimeDependencies.Add("$(BinaryOutputDir)/libiompstubs5md.dll", Path.Combine(TorchBinariesPath, "libiompstubs5md.dll"));
76-
RuntimeDependencies.Add("$(BinaryOutputDir)/pytorch_jni.dll", Path.Combine(TorchBinariesPath, "pytorch_jni.dll"));
77-
RuntimeDependencies.Add("$(BinaryOutputDir)/torch.dll", Path.Combine(TorchBinariesPath, "torch.dll"));
78-
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_cpu.dll", Path.Combine(TorchBinariesPath, "torch_cpu.dll"));
79-
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_global_deps.dll", Path.Combine(TorchBinariesPath, "torch_global_deps.dll"));
80-
RuntimeDependencies.Add("$(BinaryOutputDir)/uv.dll", Path.Combine(TorchBinariesPath, "uv.dll"));
66+
"asmjit.dll", "c10.dll", "caffe2_detectron_ops.dll", "caffe2_module_test_dynamic.dll", "fbgemm.dll", "fbjni.dll", "libiomp5md.dll",
67+
"libiompstubs5md.dll", "pytorch_jni.dll", "torch.dll", "torch_cpu.dll", "torch_global_deps.dll", "uv.dll"
68+
};
69+
70+
// copy all DLLs to the packaged build
71+
if (!Target.bBuildEditor && Target.Type == TargetType.Game)
72+
{
73+
string DllTargetDir = "$(ProjectDir)/Binaries/ThirdParty/PyTorch/";
74+
foreach (string DllName in DLLs)
75+
{
76+
PublicDelayLoadDLLs.Add(DllName);
77+
RuntimeDependencies.Add(Path.Combine(DllTargetDir, DllName), Path.Combine(TorchBinariesPath, DllName));
78+
}
79+
80+
// my PyTorch wrapper is loaded dynamically
81+
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "torchscript_wrapper.dll"), Path.Combine(TorchBinariesPath, "torchscript_wrapper.dll"));
8182
// licenses
82-
RuntimeDependencies.Add("$(BinaryOutputDir)/LICENSE.txt", Path.Combine(TorchPath, "LICENSE.txt"));
83-
RuntimeDependencies.Add("$(BinaryOutputDir)/NOTICE.txt", Path.Combine(TorchPath, "NOTICE.txt"));
83+
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "LICENSE.txt"), Path.Combine(TorchPath, "LICENSE.txt"), StagedFileType.NonUFS);
84+
RuntimeDependencies.Add(Path.Combine(DllTargetDir, "NOTICE.txt"), Path.Combine(TorchPath, "NOTICE.txt"), StagedFileType.NonUFS);
8485
}
8586
}
8687
}
23.5 KB
Binary file not shown.

0 commit comments

Comments
 (0)