Skip to content

Commit 02d9126

Browse files
authored
Multiple bug fixes
- support for multiple USimpleTorchModule per project - #issue-961466268
1 parent 8368010 commit 02d9126

6 files changed

Lines changed: 65 additions & 58 deletions

File tree

Source/SimplePyTorch/Private/SimplePyTorch.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ void FSimplePyTorchModule::StartupModule()
2020
if (ThisPlugin.IsValid())
2121
{
2222
FilePath = FPaths::ConvertRelativePathToFull(ThisPlugin->GetBaseDir());
23+
24+
FString PluginBinariesDir = FilePath / TEXT("Source/ThirdParty/pytorch") / szBinaries / szPlatform;
25+
UE_LOG(LogTemp, Log, TEXT("PyTorch third-party dlls directory: %s"), *PluginBinariesDir);
26+
FPlatformProcess::PushDllDirectory(*PluginBinariesDir);
2327
}
2428
#else
2529
FilePath = FPaths::ConvertRelativePathToFull(FPaths::ProjectDir());

Source/SimplePyTorch/Private/SimpleTorchModule.cpp

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99

1010
USimpleTorchModule::USimpleTorchModule()
11+
: ModelId(INDEX_NONE)
1112
{
1213
Buffer = NULL;
1314
BufferDims = NULL;
@@ -54,7 +55,8 @@ bool USimpleTorchModule::LoadTorchScriptModel(FString FileName)
5455
}
5556
BufferDims = new int[16];
5657

57-
return Module.FuncTSW_LoadScriptModel(TCHAR_TO_ANSI(*FileName));
58+
ModelId = Module.FuncTSW_LoadScriptModel(TCHAR_TO_ANSI(*FileName));
59+
bResult = (ModelId != INDEX_NONE);
5860
}
5961
else
6062
{
@@ -71,7 +73,7 @@ bool USimpleTorchModule::IsTorchModelLoaded() const
7173
bool bResult = false;
7274
if (Module.bDllLoaded)
7375
{
74-
return Module.FuncTSW_CheckModel();
76+
return Module.FuncTSW_CheckModel(ModelId);
7577
}
7678

7779
return bResult;
@@ -89,7 +91,7 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
8991
bool bResult = false;
9092
if (Module.bDllLoaded && Buffer != NULL && OutData.IsDataOwner())
9193
{
92-
TArray<int> InDims = InData.GetDimensions().Array();
94+
TArray<int> InDims = InData.GetDimensions();
9395

9496
float* pOutData = OutData.IsValid()
9597
? OutData.GetRawData()
@@ -98,21 +100,21 @@ bool USimpleTorchModule::ExecuteModelMethod(const FString& MethodName, const FSi
98100
int OutDimsCount = 0;
99101
if (MethodName == TEXT("forward"))
100102
{
101-
Module.FuncTSW_ForwardPass_Def(InData.GetRawData(), InDims.GetData(), InDims.Num(),
103+
Module.FuncTSW_ForwardPass_Def(ModelId, InData.GetRawData(), InDims.GetData(), InDims.Num(),
102104
pOutData, BufferDims, &OutDimsCount);
103105
}
104106
else
105107
{
106-
Module.FuncTSW_Execute_Def(TCHAR_TO_ANSI(*MethodName), InData.GetRawData(), InDims.GetData(), InDims.Num(),
108+
Module.FuncTSW_Execute_Def(ModelId, TCHAR_TO_ANSI(*MethodName), InData.GetRawData(), InDims.GetData(), InDims.Num(),
107109
pOutData, BufferDims, &OutDimsCount);
108110
}
109111

110112
bResult = (OutDimsCount > 0);
111113
if (bResult)
112114
{
113-
TArray<int32> OldOutDims = OutData.GetDimensions().Array();
115+
TArray<int32> OldOutDims = OutData.GetDimensions();
114116
bool bOutTensorMatches = (OutDimsCount == OutData.GetDimensions().Num());
115-
TSet<int32> NewOutDims;
117+
TArray<int32> NewOutDims;
116118

117119
int32 Length = 1;
118120
for (int i = 0; i < OutDimsCount; i++)
@@ -195,14 +197,14 @@ void FSimpleTorchTensor::Cleanup()
195197
}
196198
}
197199

198-
int32 FSimpleTorchTensor::GetAddress(TSet<int32> Address) const
200+
int32 FSimpleTorchTensor::GetAddress(TArray<int32> Address) const
199201
{
200202
if (!Data)
201203
{
202204
return INDEX_NONE;
203205
}
204206

205-
TArray<int32> AddrArray = Address.Array();
207+
TArray<int32> AddrArray = Address;
206208
int32 Addr = 0;
207209
for (int32 i = 0; i < AddressMultipliersCache.Num(); i++)
208210
{
@@ -218,7 +220,7 @@ int32 FSimpleTorchTensor::GetAddress(TSet<int32> Address) const
218220
return Addr;
219221
}
220222

221-
bool FSimpleTorchTensor::Create(TSet<int32> TensorDimensions)
223+
bool FSimpleTorchTensor::Create(TArray<int32> TensorDimensions)
222224
{
223225
if (Data)
224226
{
@@ -233,7 +235,7 @@ bool FSimpleTorchTensor::Create(TSet<int32> TensorDimensions)
233235

234236
if (DataSize == 0) return false;
235237

236-
Dimensions = TensorDimensions.Array();
238+
Dimensions = TensorDimensions;
237239
InitAddressSpace();
238240

239241
Data = new float[DataSize];
@@ -242,7 +244,7 @@ bool FSimpleTorchTensor::Create(TSet<int32> TensorDimensions)
242244
return true;
243245
}
244246

245-
bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> Address)
247+
bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> Address)
246248
{
247249
bDataOwner = false;
248250

@@ -296,9 +298,9 @@ bool FSimpleTorchTensor::CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> A
296298
return true;
297299
}
298300

299-
TSet<int32> FSimpleTorchTensor::GetDimensions() const
301+
TArray<int32> FSimpleTorchTensor::GetDimensions() const
300302
{
301-
TSet<int32> t;
303+
TArray<int32> t;
302304
for (const auto& d : Dimensions)
303305
t.Add(d);
304306

@@ -329,13 +331,13 @@ float* FSimpleTorchTensor::GetRawData(int32* Size) const
329331
}
330332
}
331333

332-
float* FSimpleTorchTensor::GetCell(TSet<int32> Address)
334+
float* FSimpleTorchTensor::GetCell(TArray<int32> Address)
333335
{
334336
int32 Addr = GetAddress(Address);
335337
return Addr == INDEX_NONE ? NULL : &Data[Addr];
336338
}
337339

338-
float FSimpleTorchTensor::GetValue(TSet<int32> Address) const
340+
float FSimpleTorchTensor::GetValue(TArray<int32> Address) const
339341
{
340342
int32 Addr = GetAddress(Address);
341343
return Addr == INDEX_NONE ? 0 : Data[Addr];
@@ -371,7 +373,7 @@ bool FSimpleTorchTensor::FromArray(const TArray<float>& InData)
371373
return false;
372374
}
373375

374-
bool FSimpleTorchTensor::Reshape(TSet<int32> NewShape)
376+
bool FSimpleTorchTensor::Reshape(TArray<int32> NewShape)
375377
{
376378
if (NewShape.Num() == 0)
377379
{
@@ -390,7 +392,7 @@ bool FSimpleTorchTensor::Reshape(TSet<int32> NewShape)
390392

391393
if (NewDataSize == DataSize)
392394
{
393-
Dimensions = NewShape.Array();
395+
Dimensions = NewShape;
394396
InitAddressSpace();
395397
}
396398
else
@@ -400,7 +402,7 @@ bool FSimpleTorchTensor::Reshape(TSet<int32> NewShape)
400402
Cleanup();
401403

402404
DataSize = NewDataSize;
403-
Dimensions = NewShape.Array();
405+
Dimensions = NewShape;
404406
Data = new float[DataSize];
405407

406408
InitAddressSpace();
@@ -422,7 +424,7 @@ FSimpleTorchTensor FSimpleTorchTensor::Detach()
422424
return FSimpleTorchTensor();
423425
}
424426

425-
TSet<int32> Dims;
427+
TArray<int32> Dims;
426428
for (const auto& Val : Dimensions) Dims.Add(Val);
427429

428430
FSimpleTorchTensor ret = FSimpleTorchTensor(Dims);

Source/SimplePyTorch/Public/SimplePyTorch.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
#include <vector>
99

1010
// BEGIN C HEADER
11-
typedef bool(*__TSW_LoadScriptModel)(char* FileName);
12-
typedef bool(*__TSW_CheckModel)();
13-
typedef bool(*__TSW_Forward1d)(const std::vector<float> InData, int InDimensions);
14-
typedef bool(*__TSW_ForwardTensor)(const float* InData, const int InDimensions[3], float*& OutData, int* Size1, int* Size2, int* Size3);
15-
typedef bool(*__TSW_ForwardPass_Def)(const float* InData, const int* InDimensions, int nDimensionsCount, float*& OutData, int*& OutDimensions, int* OutDimensionsCount);
16-
typedef bool(*__TSW_Execute_Def)(char* FunctionName, const float* InData, const int* InDimensions, int nDimensionsCount, float*& OutData, int*& OutDimensions, int* OutDimensionsCount);
11+
typedef int(*__TSW_LoadScriptModel)(char* FileName);
12+
typedef bool(*__TSW_CheckModel)(int ModelId);
13+
typedef bool(*__TSW_Forward1d)(int ModelId, const std::vector<float> InData, int InDimensions);
14+
typedef bool(*__TSW_ForwardTensor)(int ModelId, const float* InData, const int InDimensions[3], float*& OutData, int* Size1, int* Size2, int* Size3);
15+
typedef bool(*__TSW_ForwardPass_Def)(int ModelId, const float* InData, const int* InDimensions, int nDimensionsCount, float*& OutData, int*& OutDimensions, int* OutDimensionsCount);
16+
typedef bool(*__TSW_Execute_Def)(int ModelId, char* FunctionName, const float* InData, const int* InDimensions, int nDimensionsCount, float*& OutData, int*& OutDimensions, int* OutDimensionsCount);
1717
// END HEADER
1818

1919
class FSimplePyTorchModule : public IModuleInterface

Source/SimplePyTorch/Public/SimpleTorchModule.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
5353
void InitAddressSpace();
5454

5555
/** Get flat address in Data from multidimensional address */
56-
int32 GetAddress(TSet<int32> Address) const;
56+
int32 GetAddress(TArray<int32> Address) const;
5757
public:
5858

5959
FSimpleTorchTensor()
@@ -62,15 +62,15 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
6262
, DataSize(0)
6363
, ParentTensor(nullptr)
6464
{}
65-
FSimpleTorchTensor(FSimpleTorchTensor* Parent, TSet<int32> SubAddress)
65+
FSimpleTorchTensor(FSimpleTorchTensor* Parent, TArray<int32> SubAddress)
6666
: Data(NULL)
6767
, bDataOwner(true)
6868
, DataSize(0)
6969
, ParentTensor(nullptr)
7070
{
7171
CreateAsChild(Parent, SubAddress);
7272
}
73-
FSimpleTorchTensor(TSet<int32> Dimensions)
73+
FSimpleTorchTensor(TArray<int32> Dimensions)
7474
: Data(NULL)
7575
, bDataOwner(true)
7676
, DataSize(0)
@@ -87,10 +87,10 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
8787
void Cleanup();
8888

8989
// Set dimensions and allocate memory
90-
bool Create(TSet<int32> TensorDimensions);
90+
bool Create(TArray<int32> TensorDimensions);
9191

9292
// Create tensor as a subtensor in another tensor (share the same memory)
93-
bool CreateAsChild(FSimpleTorchTensor* Parent, TSet<int32> Address);
93+
bool CreateAsChild(FSimpleTorchTensor* Parent, TArray<int32> Address);
9494

9595
// Is tensor initialized?
9696
bool IsValid() const { return Data != NULL; }
@@ -99,7 +99,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
9999
bool IsDataOwner() const { return bDataOwner; }
100100

101101
// Get current tensor dimensions
102-
TSet<int32> GetDimensions() const;
102+
TArray<int32> GetDimensions() const;
103103

104104
// Get number of itemes in flat array
105105
int32 GetDataSize() const { return DataSize; }
@@ -112,11 +112,11 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
112112
float* GetRawData() const { return Data; }
113113

114114
// Convert multidimensional address to flat address
115-
int32 GetRawAddress(TSet<int32> Address) const { return GetAddress(Address); }
115+
int32 GetRawAddress(TArray<int32> Address) const { return GetAddress(Address); }
116116

117117
// Get reference to single float value with address
118-
float* GetCell(TSet<int32> Address);
119-
float GetValue(TSet<int32> Address) const;
118+
float* GetCell(TArray<int32> Address);
119+
float GetValue(TArray<int32> Address) const;
120120

121121
/* Create float array. Only works for tensor with one dimension.
122122
* Ex: auto p = FSimpleTorchTensor({ 4, 12 });
@@ -139,7 +139,7 @@ struct SIMPLEPYTORCH_API FSimpleTorchTensor
139139

140140
// Change dimensions.
141141
// Only keeps data if new overall size is equal to old sizse
142-
bool Reshape(TSet<int32> NewShape);
142+
bool Reshape(TArray<int32> NewShape);
143143

144144
// Create copy of this tensor
145145
FSimpleTorchTensor Detach();
@@ -190,6 +190,9 @@ class SIMPLEPYTORCH_API USimpleTorchModule : public UObject
190190
bool ExecuteModelMethod(const FString& MethodName, const FSimpleTorchTensor& InData, FSimpleTorchTensor& OutData);
191191

192192
private:
193+
/** Identified of the loaded torch script model */
194+
int32 ModelId;
195+
193196
/** Buffer size (default) */
194197
int32 BufferSize;
195198

Source/SimplePyTorch/SimplePyTorch.Build.cs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,12 @@ public SimplePyTorch(ReadOnlyTargetRules Target) : base(Target)
2222

2323
PublicIncludePaths.AddRange(
2424
new string[] {
25-
// ... add public include paths required here ...
2625
}
2726
);
2827

2928

3029
PrivateIncludePaths.AddRange(
3130
new string[] {
32-
// ... add other private include paths required here ...
3331
}
3432
);
3533

@@ -38,7 +36,6 @@ public SimplePyTorch(ReadOnlyTargetRules Target) : base(Target)
3836
new string[]
3937
{
4038
"Core",
41-
// ... add other public dependencies that you statically link with here ...
4239
}
4340
);
4441

@@ -51,39 +48,40 @@ public SimplePyTorch(ReadOnlyTargetRules Target) : base(Target)
5148
"Slate",
5249
"SlateCore",
5350
"Projects"
54-
// ... add private dependencies that you statically link with here ...
5551
}
5652
);
5753

5854

5955
DynamicallyLoadedModuleNames.AddRange(
6056
new string[]
6157
{
62-
// ... add any modules that your module loads dynamically here ...
6358
}
6459
);
6560

6661
if (Target.Platform == UnrealTargetPlatform.Win64)
6762
{
68-
// Copy DLLs to target packaged project
69-
RuntimeDependencies.Add("$(BinaryOutputDir)/asmjit.dll", Path.Combine(TorchBinariesPath, "asmjit.dll"));
70-
RuntimeDependencies.Add("$(BinaryOutputDir)/c10.dll", Path.Combine(TorchBinariesPath, "c10.dll"));
71-
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_detectron_ops.dll", Path.Combine(TorchBinariesPath, "caffe2_detectron_ops.dll"));
72-
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_module_test_dynamic.dll", Path.Combine(TorchBinariesPath, "caffe2_module_test_dynamic.dll"));
73-
RuntimeDependencies.Add("$(BinaryOutputDir)/fbgemm.dll", Path.Combine(TorchBinariesPath, "fbgemm.dll"));
74-
RuntimeDependencies.Add("$(BinaryOutputDir)/fbjni.dll", Path.Combine(TorchBinariesPath, "fbjni.dll"));
75-
RuntimeDependencies.Add("$(BinaryOutputDir)/libiomp5md.dll", Path.Combine(TorchBinariesPath, "libiomp5md.dll"));
76-
RuntimeDependencies.Add("$(BinaryOutputDir)/libiompstubs5md.dll", Path.Combine(TorchBinariesPath, "libiompstubs5md.dll"));
77-
RuntimeDependencies.Add("$(BinaryOutputDir)/pytorch_jni.dll", Path.Combine(TorchBinariesPath, "pytorch_jni.dll"));
78-
RuntimeDependencies.Add("$(BinaryOutputDir)/torch.dll", Path.Combine(TorchBinariesPath, "torch.dll"));
79-
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_cpu.dll", Path.Combine(TorchBinariesPath, "torch_cpu.dll"));
80-
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_global_deps.dll", Path.Combine(TorchBinariesPath, "torch_global_deps.dll"));
81-
RuntimeDependencies.Add("$(BinaryOutputDir)/uv.dll", Path.Combine(TorchBinariesPath, "uv.dll"));
8263
// my PyTorch wrapper
8364
RuntimeDependencies.Add("$(BinaryOutputDir)/torchscript_wrapper.dll", Path.Combine(TorchBinariesPath, "torchscript_wrapper.dll"));
84-
// licenses
85-
RuntimeDependencies.Add("$(BinaryOutputDir)/LICENSE.txt", Path.Combine(TorchPath, "LICENSE.txt"));
86-
RuntimeDependencies.Add("$(BinaryOutputDir)/NOTICE.txt", Path.Combine(TorchPath, "NOTICE.txt"));
65+
if (!Target.bBuildEditor)
66+
{
67+
// Copy DLLs to target packaged project
68+
RuntimeDependencies.Add("$(BinaryOutputDir)/asmjit.dll", Path.Combine(TorchBinariesPath, "asmjit.dll"));
69+
RuntimeDependencies.Add("$(BinaryOutputDir)/c10.dll", Path.Combine(TorchBinariesPath, "c10.dll"));
70+
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_detectron_ops.dll", Path.Combine(TorchBinariesPath, "caffe2_detectron_ops.dll"));
71+
RuntimeDependencies.Add("$(BinaryOutputDir)/caffe2_module_test_dynamic.dll", Path.Combine(TorchBinariesPath, "caffe2_module_test_dynamic.dll"));
72+
RuntimeDependencies.Add("$(BinaryOutputDir)/fbgemm.dll", Path.Combine(TorchBinariesPath, "fbgemm.dll"));
73+
RuntimeDependencies.Add("$(BinaryOutputDir)/fbjni.dll", Path.Combine(TorchBinariesPath, "fbjni.dll"));
74+
RuntimeDependencies.Add("$(BinaryOutputDir)/libiomp5md.dll", Path.Combine(TorchBinariesPath, "libiomp5md.dll"));
75+
RuntimeDependencies.Add("$(BinaryOutputDir)/libiompstubs5md.dll", Path.Combine(TorchBinariesPath, "libiompstubs5md.dll"));
76+
RuntimeDependencies.Add("$(BinaryOutputDir)/pytorch_jni.dll", Path.Combine(TorchBinariesPath, "pytorch_jni.dll"));
77+
RuntimeDependencies.Add("$(BinaryOutputDir)/torch.dll", Path.Combine(TorchBinariesPath, "torch.dll"));
78+
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_cpu.dll", Path.Combine(TorchBinariesPath, "torch_cpu.dll"));
79+
RuntimeDependencies.Add("$(BinaryOutputDir)/torch_global_deps.dll", Path.Combine(TorchBinariesPath, "torch_global_deps.dll"));
80+
RuntimeDependencies.Add("$(BinaryOutputDir)/uv.dll", Path.Combine(TorchBinariesPath, "uv.dll"));
81+
// licenses
82+
RuntimeDependencies.Add("$(BinaryOutputDir)/LICENSE.txt", Path.Combine(TorchPath, "LICENSE.txt"));
83+
RuntimeDependencies.Add("$(BinaryOutputDir)/NOTICE.txt", Path.Combine(TorchPath, "NOTICE.txt"));
84+
}
8785
}
8886
}
8987
}
Binary file not shown.

0 commit comments

Comments
 (0)