Skip to content

Commit 3ee0b27

Browse files
authored
kernel-builder: use branch from build.toml when specified (#485)
Fixes #477.
1 parent ae45d55 commit 3ee0b27

2 files changed

Lines changed: 115 additions & 27 deletions

File tree

kernel-builder/src/upload.rs

Lines changed: 111 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -61,45 +61,57 @@ pub struct UploadArgs {
6161
pub repo_type: RepoTypeArg,
6262
}
6363

64+
/// Get repository and branch from the given arguments, or fallback to
65+
/// reading `build.toml` and/or metadata otherwise.
66+
fn get_repo_and_branch(
67+
kernel_dir: &Path,
68+
repo_id: Option<String>,
69+
branch: Option<String>,
70+
variants: &[PathBuf],
71+
) -> Result<(String, Option<String>)> {
72+
let build = parse_build(kernel_dir);
73+
74+
let build_branch = build
75+
.as_ref()
76+
.ok()
77+
.and_then(|b| b.branch().map(ToOwned::to_owned));
78+
let arg_branch = branch.or(build_branch);
79+
80+
let resolved_repo_id = match repo_id {
81+
Some(id) => id,
82+
None => build
83+
.context("--repo-id is not provided and cannot parse build.toml.")?
84+
.repo_id()
85+
.ok_or_else(|| {
86+
eyre::eyre!("No `general.hub.repo-id` in build.toml. Use --repo-id to specify it.")
87+
})?
88+
.to_owned(),
89+
};
90+
91+
let version_branch =
92+
arg_branch.map_or_else(|| detect_branch_from_metadata(variants), |b| Ok(Some(b)))?;
93+
94+
Ok((resolved_repo_id, version_branch))
95+
}
96+
6497
pub fn run_upload(args: UploadArgs) -> Result<()> {
6598
let api = hf::api()?;
6699
let repo_type: RepoType = args.repo_type.into();
67100
let kernel_dir = check_or_infer_kernel_dir(args.kernel_dir)?;
68101
let kernel_dir = fs::canonicalize(&kernel_dir)
69102
.wrap_err_with(|| format!("Cannot resolve kernel directory `{}`", kernel_dir.display()))?;
70103

71-
let arg_repo_id = match args.repo_id {
72-
Some(id) => id,
73-
None =>
74-
// WARN: parsing must not be moved out of this branch, we want users
75-
// to be able to upload without `build.toml` as long as they
76-
// provide a repo id.
77-
{
78-
parse_build(&kernel_dir)
79-
.context("--repo-id is not provided and cannot parse build.toml.")?
80-
.repo_id()
81-
.ok_or_else(|| {
82-
eyre::eyre!(
83-
"No `general.hub.repo-id` in build.toml. Use --repo-id to specify it."
84-
)
85-
})?
86-
.to_owned()
87-
}
88-
};
89-
90104
let (build_dir, variants) = discover_variants(&kernel_dir)?;
91105
eprintln!(
92106
"Found {} build variant(s) in {}",
93107
variants.len(),
94108
build_dir.display()
95109
);
96110

97-
let version_branch = args
98-
.branch
99-
.map_or_else(|| detect_branch_from_metadata(&variants), |b| Ok(Some(b)))?;
111+
let (repo_id, branch) = get_repo_and_branch(&kernel_dir, args.repo_id, args.branch, &variants)?;
100112

101113
let params = CreateRepoParams::builder()
102-
.repo_id(&arg_repo_id)
114+
.repo_id(&repo_id)
103115
.repo_type(repo_type)
104116
.private(args.private)
105117
.exist_ok(true)
@@ -113,12 +125,12 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
113125
.trim_end_matches('/')
114126
.strip_prefix("https://huggingface.co/")
115127
.map(|s| s.strip_prefix("kernels/").unwrap_or(s))
116-
.unwrap_or(&arg_repo_id)
128+
.unwrap_or(&repo_id)
117129
.to_owned();
118130

119131
let repo = repo_handle(&api, repo_type, &repo_id);
120132

121-
let is_new_version_branch = if let Some(ref branch) = version_branch {
133+
let is_new_version_branch = if let Some(ref branch) = branch {
122134
let refs_params = RepoListRefsParams::builder().build();
123135
let refs = repo
124136
.list_refs(&refs_params)
@@ -149,7 +161,7 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
149161
.or_default(),
150162
);
151163

152-
if let Some(ref branch) = version_branch {
164+
if let Some(ref branch) = branch {
153165
let params = RepoListFilesParams {
154166
revision: Some(branch.clone()),
155167
};
@@ -227,7 +239,7 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
227239
RepoType::Kernel => "kernels/",
228240
_ => "",
229241
};
230-
let tree_path = version_branch
242+
let tree_path = branch
231243
.as_ref()
232244
.map_or(String::new(), |b| format!("/tree/{b}"));
233245
println!("Kernel uploaded: https://hf.co/{type_prefix}{repo_id}{tree_path}");
@@ -781,4 +793,76 @@ mod tests {
781793
let result = discover_build_file(temp_dir.path(), "CARD.md");
782794
assert!(result.is_err());
783795
}
796+
797+
#[test]
798+
fn test_branch_from_build_toml() {
799+
let temp_dir = tempfile::tempdir().unwrap();
800+
let kernel_dir = temp_dir.path();
801+
802+
fs::write(
803+
kernel_dir.join("build.toml"),
804+
r#"[general]
805+
name = "test-kernel"
806+
backends = ["cuda"]
807+
808+
[general.hub]
809+
repo-id = "test/kernel"
810+
branch = "custom-branch"
811+
"#,
812+
)
813+
.unwrap();
814+
815+
let build_dir = kernel_dir.join("build");
816+
let variant = build_dir.join("torch-cuda");
817+
fs::create_dir_all(&variant).unwrap();
818+
fs::write(variant.join("metadata.json"), METADATA_V3).unwrap();
819+
fs::write(variant.join("kernel.so"), "binary").unwrap();
820+
821+
let variants = vec![variant.clone()];
822+
let (repo_id, branch) = get_repo_and_branch(kernel_dir, None, None, &variants).unwrap();
823+
824+
assert_eq!(repo_id, "test/kernel");
825+
assert_eq!(branch, Some("custom-branch".to_owned()));
826+
827+
// Verify commit ops are generated - these would be uploaded to the branch above.
828+
let mut operations = vec![];
829+
collect_build_commit_ops(&build_dir, &variants, &[], false, &mut operations).unwrap();
830+
assert!(!operations.is_empty());
831+
}
832+
833+
#[test]
834+
fn test_args_take_priority_over_files() {
835+
let temp_dir = tempfile::tempdir().unwrap();
836+
let kernel_dir = temp_dir.path();
837+
838+
fs::write(
839+
kernel_dir.join("build.toml"),
840+
r#"[general]
841+
name = "test-kernel"
842+
backends = ["cuda"]
843+
844+
[general.hub]
845+
repo-id = "build-toml/kernel"
846+
branch = "build-toml-branch"
847+
"#,
848+
)
849+
.unwrap();
850+
851+
let build_dir = kernel_dir.join("build");
852+
let variant = build_dir.join("torch-cuda");
853+
fs::create_dir_all(&variant).unwrap();
854+
fs::write(variant.join("metadata.json"), METADATA_V3).unwrap();
855+
856+
let variants = vec![variant];
857+
let (repo_id, branch) = get_repo_and_branch(
858+
kernel_dir,
859+
Some("args/kernel".to_owned()),
860+
Some("args-branch".to_owned()),
861+
&variants,
862+
)
863+
.unwrap();
864+
865+
assert_eq!(repo_id, "args/kernel");
866+
assert_eq!(branch, Some("args-branch".to_owned()));
867+
}
784868
}

kernels-data/src/config/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ impl Build {
5353
self.kernels.is_empty()
5454
}
5555

56+
pub fn branch(&self) -> Option<&str> {
57+
self.general.hub.as_ref().and_then(|h| h.branch.as_deref())
58+
}
59+
5660
pub fn repo_id(&self) -> Option<&str> {
5761
self.general.hub.as_ref().and_then(|h| h.repo_id.as_deref())
5862
}

0 commit comments

Comments
 (0)