Skip to content

Commit dde257a

Browse files
CopilotLeftofZen
andauthored
fix: address download review feedback
Agent-Logs-Url: https://github.com/OpenLoco/ObjectEditor/sessions/8fca8cd2-b4dd-4d4e-928a-37cfb8017f82 Co-authored-by: LeftofZen <7483209+LeftofZen@users.noreply.github.com>
1 parent 7b3b841 commit dde257a

6 files changed

Lines changed: 285 additions & 28 deletions

File tree

Gui/ViewModels/FolderTreeViewModel.cs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
using Gui.Models;
1616
using Gui.ViewModels.Filters;
1717
using Index;
18+
using MsBox.Avalonia;
19+
using MsBox.Avalonia.Dto;
20+
using MsBox.Avalonia.Enums;
1821
using ReactiveUI;
1922
using ReactiveUI.Fody.Helpers;
2023
using System;
@@ -824,8 +827,29 @@ async Task DownloadOnlinePackAsync(OnlineItemPackBrowseResult pack)
824827
}
825828

826829
var safePackName = Path.GetInvalidFileNameChars().Aggregate(pack.Name, (current, c) => current.Replace(c, '_'));
827-
var filename = Path.Combine(EditorContext.Settings.DownloadFolder, $"{safePackName}.zip");
828-
await File.WriteAllBytesAsync(filename, fileBytes);
829-
EditorContext.Logger.Info($"Downloaded pack \"{pack.Name}\" to \"{filename}\"");
830+
var filename = Path.Combine(EditorContext.Settings.DownloadFolder, $"{safePackName}-{pack.Id}.zip");
831+
try
832+
{
833+
await using var outputStream = new FileStream(filename, FileMode.CreateNew, FileAccess.Write, FileShare.None, 4096, useAsync: true);
834+
await outputStream.WriteAsync(fileBytes);
835+
EditorContext.Logger.Info($"Downloaded pack \"{pack.Name}\" to \"{filename}\"");
836+
}
837+
catch (IOException ex)
838+
{
839+
EditorContext.Logger.Error($"Failed to download pack \"{pack.Name}\" to \"{filename}\"", ex);
840+
var box = MessageBoxManager.GetMessageBoxStandard(new MessageBoxStandardParams
841+
{
842+
ContentTitle = "Download failed",
843+
ContentMessage = $"Could not create:\n{filename}\n\nA file may already exist or be locked by another process.",
844+
ButtonDefinitions = ButtonEnum.Ok,
845+
Icon = Icon.Warning,
846+
WindowStartupLocation = WindowStartupLocation.CenterOwner,
847+
Topmost = true,
848+
ShowInCenter = true,
849+
SizeToContent = SizeToContent.WidthAndHeight,
850+
MinHeight = 170,
851+
});
852+
_ = await box.ShowAsync();
853+
}
830854
}
831855
}

ObjectService/RouteHandlers/RouteHelpers.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,42 @@ namespace ObjectService.RouteHandlers;
22

33
public static class RouteHelpers
44
{
5+
static readonly char[] PathSeparators = [Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar];
6+
57
public static string MakeNicePlural(string name)
68
=> $"{name.Replace("RouteHandler", string.Empty)}s";
9+
10+
public static bool TryGetSafeRelativePathUnderRoot(string rootPath, string? relativePath, out string fullPath, out string normalizedRelativePath)
11+
{
12+
fullPath = string.Empty;
13+
normalizedRelativePath = string.Empty;
14+
15+
if (string.IsNullOrWhiteSpace(relativePath) || Path.IsPathRooted(relativePath))
16+
{
17+
return false;
18+
}
19+
20+
var segments = relativePath.Split(PathSeparators, StringSplitOptions.None);
21+
if (segments.Any(x => string.IsNullOrEmpty(x) || x is "." or ".."))
22+
{
23+
return false;
24+
}
25+
26+
var rootFullPath = Path.GetFullPath(rootPath);
27+
if (!rootFullPath.EndsWith(Path.DirectorySeparatorChar))
28+
{
29+
rootFullPath += Path.DirectorySeparatorChar;
30+
}
31+
32+
var combinedPath = Path.Combine(rootFullPath, relativePath);
33+
var candidateFullPath = Path.GetFullPath(combinedPath);
34+
if (!candidateFullPath.StartsWith(rootFullPath, StringComparison.OrdinalIgnoreCase))
35+
{
36+
return false;
37+
}
38+
39+
fullPath = candidateFullPath;
40+
normalizedRelativePath = relativePath.Replace('\\', '/');
41+
return true;
42+
}
743
}

ObjectService/RouteHandlers/TableHandlers/ObjectPackRouteHandler.cs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,14 @@ public static async Task<IResult> GetObjectPackFileAsync([FromRoute] UniqueObjec
6969
}
7070

7171
var sfm = sp.GetRequiredService<ServerFolderManager>();
72-
var zipStream = new MemoryStream();
72+
var tempZipPath = Path.Combine(Path.GetTempPath(), $"{Guid.NewGuid():N}.zip");
73+
var zipStream = new FileStream(
74+
tempZipPath,
75+
FileMode.Create,
76+
FileAccess.ReadWrite,
77+
FileShare.None,
78+
bufferSize: 4096,
79+
options: FileOptions.Asynchronous | FileOptions.DeleteOnClose);
7380

7481
using (var archive = new ZipArchive(zipStream, ZipArchiveMode.Create, leaveOpen: true))
7582
{
@@ -82,19 +89,25 @@ public static async Task<IResult> GetObjectPackFileAsync([FromRoute] UniqueObjec
8289

8390
foreach (var dat in obj.DatObjects)
8491
{
85-
if (!sfm.ObjectIndex.TryFind((dat.DatName, dat.DatChecksum), out var entry) || entry == null || string.IsNullOrEmpty(entry.FileName))
92+
if (!sfm.ObjectIndex.TryFind((dat.DatName, dat.DatChecksum), out var entry) || entry == null)
8693
{
8794
continue;
8895
}
8996

90-
var filePath = Path.Combine(sfm.ObjectsFolder, entry.FileName);
91-
if (!File.Exists(filePath))
97+
if (!RouteHelpers.TryGetSafeRelativePathUnderRoot(sfm.ObjectsFolder, entry.FileName, out var fullFilePath, out var entryName))
9298
{
9399
continue;
94100
}
95101

96-
// Use the relative path from the objects folder as the entry name to avoid duplicate filename collisions
97-
archive.CreateEntryFromFile(filePath, entry.FileName.Replace('\\', '/'));
102+
if (!File.Exists(fullFilePath))
103+
{
104+
continue;
105+
}
106+
107+
await using var fileStream = File.OpenRead(fullFilePath);
108+
var zipEntry = archive.CreateEntry(entryName, CompressionLevel.Optimal);
109+
await using var entryStream = zipEntry.Open();
110+
await fileStream.CopyToAsync(entryStream);
98111
}
99112
}
100113
}

ObjectService/RouteHandlers/TableHandlers/SC5FilePackRouteHandler.cs

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,33 @@ public static async Task<IResult> GetSC5FilePackFileAsync([FromRoute] UniqueObje
6767
}
6868

6969
var sfm = sp.GetRequiredService<ServerFolderManager>();
70-
var zipStream = new MemoryStream();
70+
var tempZipPath = Path.Combine(Path.GetTempPath(), $"{Guid.NewGuid():N}.zip");
71+
var zipStream = new FileStream(
72+
tempZipPath,
73+
FileMode.Create,
74+
FileAccess.ReadWrite,
75+
FileShare.None,
76+
bufferSize: 4096,
77+
options: FileOptions.Asynchronous | FileOptions.DeleteOnClose);
7178

7279
using (var archive = new ZipArchive(zipStream, ZipArchiveMode.Create, leaveOpen: true))
7380
{
7481
foreach (var sc5File in pack.SC5Files)
7582
{
76-
var filePath = Path.Combine(sfm.ScenariosFolder, sc5File.Name);
77-
if (!File.Exists(filePath))
83+
if (!RouteHelpers.TryGetSafeRelativePathUnderRoot(sfm.ScenariosFolder, sc5File.Name, out var fullFilePath, out var entryName))
7884
{
7985
continue;
8086
}
8187

82-
// Use the relative path as the entry name to avoid duplicate filename collisions
83-
archive.CreateEntryFromFile(filePath, sc5File.Name.Replace('\\', '/'));
88+
if (!File.Exists(fullFilePath))
89+
{
90+
continue;
91+
}
92+
93+
await using var fileStream = File.OpenRead(fullFilePath);
94+
var entry = archive.CreateEntry(entryName, CompressionLevel.Optimal);
95+
await using var entryStream = entry.Open();
96+
await fileStream.CopyToAsync(entryStream);
8497
}
8598
}
8699

ObjectService/RouteHandlers/TableHandlers/ScenarioRouteHandler.cs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ public static void MapAdditionalRoutes(IEndpointRouteBuilder parentRoute)
2323
}
2424

2525
static string[] GetSortedScenarioFiles(string scenarioFolder)
26-
=> [.. Directory.GetFiles(scenarioFolder, "*.SC5", SearchOption.AllDirectories).OrderBy(x => x)];
26+
=> [.. Directory
27+
.GetFiles(scenarioFolder, "*.SC5", SearchOption.AllDirectories)
28+
.OrderBy(x => Path.GetRelativePath(scenarioFolder, x), StringComparer.Ordinal)];
2729

2830
static async Task<IResult> ListAsync([FromServices] IServiceProvider sp)
2931
=> await Task.Run(() =>
@@ -36,21 +38,20 @@ static async Task<IResult> ListAsync([FromServices] IServiceProvider sp)
3638
return Results.Ok(filenames.ToList());
3739
});
3840

39-
static async Task<IResult> GetScenarioFileAsync([FromRoute] UniqueObjectId id, [FromServices] IServiceProvider sp)
40-
=> await Task.Run(() =>
41-
{
42-
var sfm = sp.GetRequiredService<ServerFolderManager>();
43-
var files = GetSortedScenarioFiles(sfm.ScenariosFolder);
41+
static Task<IResult> GetScenarioFileAsync([FromRoute] UniqueObjectId id, [FromServices] IServiceProvider sp)
42+
{
43+
var sfm = sp.GetRequiredService<ServerFolderManager>();
44+
var files = GetSortedScenarioFiles(sfm.ScenariosFolder);
4445

45-
if (id >= (ulong)files.Length)
46-
{
47-
return Results.NotFound();
48-
}
46+
if (id >= (ulong)files.Length)
47+
{
48+
return Task.FromResult<IResult>(Results.NotFound());
49+
}
4950

50-
var path = files[(int)id];
51-
const string contentType = "application/octet-stream";
52-
return Results.File(path, contentType, Path.GetFileName(path));
53-
});
51+
var path = files[(int)id];
52+
const string contentType = "application/octet-stream";
53+
return Task.FromResult<IResult>(Results.File(path, contentType, Path.GetFileName(path)));
54+
}
5455

5556
static async Task<IResult> CreateAsync(DtoScenarioEntry request)
5657
=> await Task.Run(() => Results.Problem(statusCode: StatusCodes.Status501NotImplemented));
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
using Definitions;
2+
using Definitions.Database;
3+
using Definitions.ObjectModels.Types;
4+
using Definitions.Web;
5+
using Index;
6+
using Microsoft.Extensions.Configuration;
7+
using Microsoft.Extensions.DependencyInjection;
8+
using NUnit.Framework;
9+
using System.IO.Compression;
10+
11+
namespace ObjectService.Tests.Integration;
12+
13+
[TestFixture]
14+
public class DownloadRoutesTest
15+
{
16+
HttpClient? httpClient;
17+
TestWebApplicationFactory<Program>? testWebAppFactory;
18+
19+
[SetUp]
20+
public void SetUp()
21+
{
22+
testWebAppFactory = new TestWebApplicationFactory<Program>();
23+
httpClient = testWebAppFactory.CreateClient();
24+
}
25+
26+
[TearDown]
27+
public void TearDown()
28+
{
29+
httpClient?.Dispose();
30+
testWebAppFactory?.Dispose();
31+
}
32+
33+
[Test]
34+
public async Task GetScenarioFileAsync_ReturnsFileMatchingSortedListOrder()
35+
{
36+
using var scope = testWebAppFactory!.Services.CreateScope();
37+
var sfm = scope.ServiceProvider.GetRequiredService<ServerFolderManager>();
38+
39+
var alphaRelativePath = Path.Combine(ServerFolderManager.CustomFolderName, "alpha.SC5");
40+
var zuluRelativePath = Path.Combine(ServerFolderManager.CustomFolderName, "zulu.SC5");
41+
var alphaPath = Path.Combine(sfm.ScenariosFolder, alphaRelativePath);
42+
var zuluPath = Path.Combine(sfm.ScenariosFolder, zuluRelativePath);
43+
44+
await File.WriteAllBytesAsync(zuluPath, [9, 9, 9]);
45+
await File.WriteAllBytesAsync(alphaPath, [1, 2, 3]);
46+
47+
var list = await Client.GetListAsync<Definitions.DTO.DtoScenarioEntry>(httpClient!, Client.ScenariosEndpointGroup);
48+
var firstScenario = list.First();
49+
50+
using var response = await httpClient!.GetAsync($"{RoutesV2.Prefix}{RoutesV2.Scenarios}/{firstScenario.Id}{RoutesV2.File}");
51+
var bytes = await response.Content.ReadAsByteArrayAsync();
52+
53+
using (Assert.EnterMultipleScope())
54+
{
55+
Assert.That(response.IsSuccessStatusCode, Is.True);
56+
Assert.That(firstScenario.Name, Is.EqualTo(alphaRelativePath));
57+
Assert.That(bytes, Is.EqualTo(new byte[] { 1, 2, 3 }));
58+
}
59+
}
60+
61+
[Test]
62+
public async Task GetSC5FilePackFileAsync_ReturnsZipWithOnlySafeScenarioEntries()
63+
{
64+
using var scope = testWebAppFactory!.Services.CreateScope();
65+
var sp = scope.ServiceProvider;
66+
var db = sp.GetRequiredService<LocoDbContext>();
67+
var sfm = sp.GetRequiredService<ServerFolderManager>();
68+
var config = sp.GetRequiredService<IConfiguration>();
69+
var rootFolder = config["ObjectService:RootFolder"];
70+
ArgumentNullException.ThrowIfNull(rootFolder);
71+
72+
var safeRelativePath = Path.Combine(ServerFolderManager.CustomFolderName, "pack-safe.SC5");
73+
var safePath = Path.Combine(sfm.ScenariosFolder, safeRelativePath);
74+
await File.WriteAllBytesAsync(safePath, [1, 2, 3, 4]);
75+
76+
var outsidePath = Path.Combine(rootFolder, "outside.SC5");
77+
await File.WriteAllBytesAsync(outsidePath, [7, 7, 7]);
78+
79+
var pack = new TblSC5FilePack
80+
{
81+
Id = 1,
82+
Name = "Scenario Pack",
83+
SC5Files =
84+
[
85+
new TblSC5File { Id = 1, Name = safeRelativePath },
86+
new TblSC5File { Id = 2, Name = Path.Combine("..", "outside.SC5") },
87+
],
88+
};
89+
90+
_ = await db.SC5FilePacks.AddAsync(pack);
91+
_ = await db.SaveChangesAsync();
92+
93+
using var response = await httpClient!.GetAsync($"{RoutesV2.Prefix}{RoutesV2.SC5FilePacks}/{pack.Id}{RoutesV2.File}");
94+
var bytes = await response.Content.ReadAsByteArrayAsync();
95+
using var archive = new ZipArchive(new MemoryStream(bytes), ZipArchiveMode.Read);
96+
97+
using (Assert.EnterMultipleScope())
98+
{
99+
Assert.That(response.IsSuccessStatusCode, Is.True);
100+
Assert.That(response.Content.Headers.ContentType?.MediaType, Is.EqualTo("application/zip"));
101+
Assert.That(archive.Entries.Select(x => x.FullName), Is.EqualTo(new[] { safeRelativePath.Replace('\\', '/') }));
102+
await using var entryStream = archive.Entries.Single().Open();
103+
using var entryMemoryStream = new MemoryStream();
104+
await entryStream.CopyToAsync(entryMemoryStream);
105+
Assert.That(entryMemoryStream.ToArray(), Is.EqualTo(new byte[] { 1, 2, 3, 4 }));
106+
}
107+
}
108+
109+
[Test]
110+
public async Task GetObjectPackFileAsync_ReturnsZipWithOnlySafeIndexedObjectEntries()
111+
{
112+
using var scope = testWebAppFactory!.Services.CreateScope();
113+
var sp = scope.ServiceProvider;
114+
var db = sp.GetRequiredService<LocoDbContext>();
115+
var sfm = sp.GetRequiredService<ServerFolderManager>();
116+
var config = sp.GetRequiredService<IConfiguration>();
117+
var rootFolder = config["ObjectService:RootFolder"];
118+
ArgumentNullException.ThrowIfNull(rootFolder);
119+
120+
var safeRelativePath = Path.Combine(ServerFolderManager.CustomFolderName, "safe-object.dat");
121+
var safePath = Path.Combine(sfm.ObjectsFolder, safeRelativePath);
122+
await File.WriteAllBytesAsync(safePath, [4, 3, 2, 1]);
123+
124+
var outsidePath = Path.Combine(rootFolder, "outside.dat");
125+
await File.WriteAllBytesAsync(outsidePath, [8, 8, 8]);
126+
127+
sfm.ObjectIndex.Objects.Add(new ObjectIndexEntry("SAFEOBJ", safeRelativePath, null, 111, null, ObjectType.Vehicle, ObjectSource.Custom, null, null));
128+
sfm.ObjectIndex.Objects.Add(new ObjectIndexEntry("UNSAFEOBJ", Path.Combine("..", "outside.dat"), null, 222, null, ObjectType.Vehicle, ObjectSource.Custom, null, null));
129+
130+
var obj = new TblObject
131+
{
132+
Id = 1,
133+
Name = "safe-obj",
134+
SubObjectId = 1,
135+
ObjectType = ObjectType.Vehicle,
136+
ObjectSource = ObjectSource.Custom,
137+
Availability = ObjectAvailability.Available,
138+
DatObjects =
139+
[
140+
new TblDatObject { Id = 1, DatName = "SAFEOBJ", DatChecksum = 111, xxHash3 = 1, ObjectId = 1 },
141+
new TblDatObject { Id = 2, DatName = "UNSAFEOBJ", DatChecksum = 222, xxHash3 = 2, ObjectId = 1 },
142+
],
143+
};
144+
145+
var pack = new TblObjectPack
146+
{
147+
Id = 1,
148+
Name = "Object Pack",
149+
Objects = [obj],
150+
};
151+
152+
_ = await db.ObjectPacks.AddAsync(pack);
153+
_ = await db.SaveChangesAsync();
154+
155+
using var response = await httpClient!.GetAsync($"{RoutesV2.Prefix}{RoutesV2.ObjectPacks}/{pack.Id}{RoutesV2.File}");
156+
var bytes = await response.Content.ReadAsByteArrayAsync();
157+
using var archive = new ZipArchive(new MemoryStream(bytes), ZipArchiveMode.Read);
158+
159+
using (Assert.EnterMultipleScope())
160+
{
161+
Assert.That(response.IsSuccessStatusCode, Is.True);
162+
Assert.That(response.Content.Headers.ContentType?.MediaType, Is.EqualTo("application/zip"));
163+
Assert.That(archive.Entries.Select(x => x.FullName), Is.EqualTo(new[] { safeRelativePath.Replace('\\', '/') }));
164+
await using var entryStream = archive.Entries.Single().Open();
165+
using var entryMemoryStream = new MemoryStream();
166+
await entryStream.CopyToAsync(entryMemoryStream);
167+
Assert.That(entryMemoryStream.ToArray(), Is.EqualTo(new byte[] { 4, 3, 2, 1 }));
168+
}
169+
}
170+
}

0 commit comments

Comments
 (0)