Skip to content
167 changes: 136 additions & 31 deletions packages/hub/src/lib/copy-files.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { checkCredentials } from "../utils/checkCredentials";
import { formatBytes } from "../utils/formatBytes";
import { promisesQueue } from "../utils/promisesQueue";
import { toRepoId } from "../utils/toRepoId";
import { eventToGenerator } from "../utils/eventToGenerator";
import type { CommitOperation, CommitParams } from "./commit";
import { commit } from "./commit";
import { downloadFile } from "./download-file";
Expand All @@ -11,6 +12,22 @@ import { listFiles } from "./list-files";
import type { PathInfo } from "./paths-info";
import { pathsInfo } from "./paths-info";

/**
* Progress events yielded by {@link copyFileIter} / {@link copyFilesIter} / {@link copyFolderIter}.
*
* Currently only `fileDownloaded` is emitted: one event per source file that had to be downloaded
* (small git-stored files that can't be copied server-side). Xet-backed files are copied
* server-side and do not produce events.
*/
export interface CopyProgressEvent {
event: "fileDownloaded";
/** Source path of the file that was just downloaded. */
path: string;
/** Number of files downloaded so far (including this one). */
downloaded: number;
/** Total number of files that will be downloaded. */
total: number;
}
const DOWNLOAD_CONCURRENCY = 5;
const PATHS_INFO_BATCH_SIZE = 100;
const MAX_REPORTED_LFS_PATHS = 5;
Expand Down Expand Up @@ -112,6 +129,39 @@ export function copyFile(
});
}

/**
* Async-iterator variant of {@link copyFile} that yields {@link CopyProgressEvent}s while
* downloading non-xet source files (xet-backed files are copied server-side and do not
* emit events). See {@link copyFile} for the semantics.
*
* @example
* ```ts
* for await (const event of copyFileIter({ source, destination, accessToken })) {
* console.log(`downloaded ${event.path} (${event.downloaded}/${event.total})`);
* }
* ```
*/
export function copyFileIter(
params: {
source: CopySource;
destination: CopyDestination;
} & SharedParams,
): AsyncGenerator<CopyProgressEvent, undefined> {
return copyFilesIter({
...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
destination: params.destination.repo,
files: [
{
source: params.source,
destinationPath: params.destination.path,
},
],
hubUrl: params.hubUrl,
fetch: params.fetch,
abortSignal: params.abortSignal,
});
}

/**
* Copy multiple files (potentially from different source repos/buckets) to the destination
* bucket in a single commit.
Expand Down Expand Up @@ -152,11 +202,31 @@ export async function copyFiles(
files: CopyFilesEntry[];
} & SharedParams,
): Promise<undefined> {
const iterator = copyFilesIter(params);
while (true) {
const res = await iterator.next();
if (res.done) {
return undefined;
}
}
}

/**
* Async-iterator variant of {@link copyFiles} that yields {@link CopyProgressEvent}s while
* downloading non-xet source files (xet-backed files are copied server-side and do not
* emit events). See {@link copyFiles} for the semantics.
*/
export async function* copyFilesIter(
params: {
destination: BucketDesignation;
files: CopyFilesEntry[];
} & SharedParams,
): AsyncGenerator<CopyProgressEvent, undefined> {
if (params.files.length === 0) {
return undefined;
}

const operations = await resolveCopyOperations(params, params.files);
const operations = yield* resolveCopyOperationsIter(params, params.files);

await commit({
...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
Expand Down Expand Up @@ -210,6 +280,26 @@ export async function copyFolder(
destination: Omit<CopyDestination, "path"> & { path?: string };
} & SharedParams,
): Promise<undefined> {
const iterator = copyFolderIter(params);
while (true) {
const res = await iterator.next();
if (res.done) {
return undefined;
}
}
}

/**
* Async-iterator variant of {@link copyFolder} that yields {@link CopyProgressEvent}s while
* downloading non-xet source files (xet-backed files are copied server-side and do not
* emit events). See {@link copyFolder} for the semantics.
*/
export async function* copyFolderIter(
params: {
source: Omit<CopySource, "path"> & { path?: string };
destination: Omit<CopyDestination, "path"> & { path?: string };
} & SharedParams,
): AsyncGenerator<CopyProgressEvent, undefined> {
const accessToken = checkCredentials(params);
const sourceRepoId = toRepoId(params.source.repo);
const sourcePath = (params.source.path ?? "").replace(/\/+$/, "");
Expand Down Expand Up @@ -273,7 +363,7 @@ export async function copyFolder(
return undefined;
}

await downloadAndFillBlobs({
yield* downloadAndFillBlobsIter({
pendingDownloads,
operations,
accessToken,
Expand All @@ -296,8 +386,12 @@ export async function copyFolder(
/**
* Resolve a list of {@link CopyFilesEntry} entries into `CommitOperation`s, batching
* `pathsInfo` calls per source repo and parallelizing downloads for non-xet files.
* Yields one {@link CopyProgressEvent} per downloaded file.
*/
async function resolveCopyOperations(shared: SharedParams, files: CopyFilesEntry[]): Promise<CommitOperation[]> {
async function* resolveCopyOperationsIter(
shared: SharedParams,
files: CopyFilesEntry[],
): AsyncGenerator<CopyProgressEvent, CommitOperation[]> {
const accessToken = checkCredentials(shared);

// Group files by (source repo, source revision) so we can batch pathsInfo calls.
Expand Down Expand Up @@ -391,7 +485,7 @@ async function resolveCopyOperations(shared: SharedParams, files: CopyFilesEntry
}
}

await downloadAndFillBlobs({
yield* downloadAndFillBlobsIter({
pendingDownloads,
operations,
accessToken,
Expand All @@ -411,39 +505,50 @@ interface PendingDownload {

/**
* Download all `pendingDownloads` in parallel and fill the matching `addOrUpdate`
* placeholder ops in `operations` with the downloaded blob. No-op if the list is empty.
* placeholder ops in `operations` with the downloaded blob. Yields one
* {@link CopyProgressEvent} per file as it completes. No-op if the list is empty.
*/
async function downloadAndFillBlobs(args: {
function downloadAndFillBlobsIter(args: {
pendingDownloads: PendingDownload[];
operations: CommitOperation[];
accessToken: string | undefined;
hubUrl: string | undefined;
fetch: typeof fetch | undefined;
}): Promise<void> {
if (args.pendingDownloads.length === 0) {
return;
}
await promisesQueue(
args.pendingDownloads.map(({ index, repoId, revision, sourcePath }) => async () => {
const blob = await downloadFile({
repo: repoId,
path: sourcePath,
revision,
accessToken: args.accessToken,
hubUrl: args.hubUrl,
fetch: args.fetch,
});
if (!blob) {
throw new Error(`Failed to download '${sourcePath}' from ${repoId.type}s/${repoId.name}`);
}
const op = args.operations[index];
if (op.operation !== "addOrUpdate") {
throw new Error("Internal: expected addOrUpdate placeholder operation");
}
op.content = blob;
}),
DOWNLOAD_CONCURRENCY,
);
}): AsyncGenerator<CopyProgressEvent, void> {
const total = args.pendingDownloads.length;
return eventToGenerator<CopyProgressEvent, void>((yieldCallback, returnCallback, rejectCallback) => {
if (total === 0) {
returnCallback();
return;
}
let downloaded = 0;
promisesQueue(
args.pendingDownloads.map(({ index, repoId, revision, sourcePath }) => async () => {
const blob = await downloadFile({
repo: repoId,
path: sourcePath,
revision,
accessToken: args.accessToken,
hubUrl: args.hubUrl,
fetch: args.fetch,
});
if (!blob) {
throw new Error(`Failed to download '${sourcePath}' from ${repoId.type}s/${repoId.name}`);
}
const op = args.operations[index];
if (op.operation !== "addOrUpdate") {
throw new Error("Internal: expected addOrUpdate placeholder operation");
}
op.content = blob;
downloaded++;
yieldCallback({ event: "fileDownloaded", path: sourcePath, downloaded, total });
}),
DOWNLOAD_CONCURRENCY,
).then(
() => returnCallback(),
(err) => rejectCallback(err),
);
});
}

/**
Expand Down
Loading