@@ -61,45 +61,57 @@ pub struct UploadArgs {
6161 pub repo_type : RepoTypeArg ,
6262}
6363
64+ /// Get repository and branch from the given arguments, or fallback to
65+ /// reading `build.toml` and/or metadata otherwise.
66+ fn get_repo_and_branch (
67+ kernel_dir : & Path ,
68+ repo_id : Option < String > ,
69+ branch : Option < String > ,
70+ variants : & [ PathBuf ] ,
71+ ) -> Result < ( String , Option < String > ) > {
72+ let build = parse_build ( kernel_dir) ;
73+
74+ let build_branch = build
75+ . as_ref ( )
76+ . ok ( )
77+ . and_then ( |b| b. branch ( ) . map ( ToOwned :: to_owned) ) ;
78+ let arg_branch = branch. or ( build_branch) ;
79+
80+ let resolved_repo_id = match repo_id {
81+ Some ( id) => id,
82+ None => build
83+ . context ( "--repo-id is not provided and cannot parse build.toml." ) ?
84+ . repo_id ( )
85+ . ok_or_else ( || {
86+ eyre:: eyre!( "No `general.hub.repo-id` in build.toml. Use --repo-id to specify it." )
87+ } ) ?
88+ . to_owned ( ) ,
89+ } ;
90+
91+ let version_branch =
92+ arg_branch. map_or_else ( || detect_branch_from_metadata ( variants) , |b| Ok ( Some ( b) ) ) ?;
93+
94+ Ok ( ( resolved_repo_id, version_branch) )
95+ }
96+
6497pub fn run_upload ( args : UploadArgs ) -> Result < ( ) > {
6598 let api = hf:: api ( ) ?;
6699 let repo_type: RepoType = args. repo_type . into ( ) ;
67100 let kernel_dir = check_or_infer_kernel_dir ( args. kernel_dir ) ?;
68101 let kernel_dir = fs:: canonicalize ( & kernel_dir)
69102 . wrap_err_with ( || format ! ( "Cannot resolve kernel directory `{}`" , kernel_dir. display( ) ) ) ?;
70103
71- let arg_repo_id = match args. repo_id {
72- Some ( id) => id,
73- None =>
74- // WARN: parsing must not be moved out of this branch, we want users
75- // to be able to upload without `build.toml` as long as they
76- // provide a repo id.
77- {
78- parse_build ( & kernel_dir)
79- . context ( "--repo-id is not provided and cannot parse build.toml." ) ?
80- . repo_id ( )
81- . ok_or_else ( || {
82- eyre:: eyre!(
83- "No `general.hub.repo-id` in build.toml. Use --repo-id to specify it."
84- )
85- } ) ?
86- . to_owned ( )
87- }
88- } ;
89-
90104 let ( build_dir, variants) = discover_variants ( & kernel_dir) ?;
91105 eprintln ! (
92106 "Found {} build variant(s) in {}" ,
93107 variants. len( ) ,
94108 build_dir. display( )
95109 ) ;
96110
97- let version_branch = args
98- . branch
99- . map_or_else ( || detect_branch_from_metadata ( & variants) , |b| Ok ( Some ( b) ) ) ?;
111+ let ( repo_id, branch) = get_repo_and_branch ( & kernel_dir, args. repo_id , args. branch , & variants) ?;
100112
101113 let params = CreateRepoParams :: builder ( )
102- . repo_id ( & arg_repo_id )
114+ . repo_id ( & repo_id )
103115 . repo_type ( repo_type)
104116 . private ( args. private )
105117 . exist_ok ( true )
@@ -113,12 +125,12 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
113125 . trim_end_matches ( '/' )
114126 . strip_prefix ( "https://huggingface.co/" )
115127 . map ( |s| s. strip_prefix ( "kernels/" ) . unwrap_or ( s) )
116- . unwrap_or ( & arg_repo_id )
128+ . unwrap_or ( & repo_id )
117129 . to_owned ( ) ;
118130
119131 let repo = repo_handle ( & api, repo_type, & repo_id) ;
120132
121- let is_new_version_branch = if let Some ( ref branch) = version_branch {
133+ let is_new_version_branch = if let Some ( ref branch) = branch {
122134 let refs_params = RepoListRefsParams :: builder ( ) . build ( ) ;
123135 let refs = repo
124136 . list_refs ( & refs_params)
@@ -149,7 +161,7 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
149161 . or_default ( ) ,
150162 ) ;
151163
152- if let Some ( ref branch) = version_branch {
164+ if let Some ( ref branch) = branch {
153165 let params = RepoListFilesParams {
154166 revision : Some ( branch. clone ( ) ) ,
155167 } ;
@@ -227,7 +239,7 @@ pub fn run_upload(args: UploadArgs) -> Result<()> {
227239 RepoType :: Kernel => "kernels/" ,
228240 _ => "" ,
229241 } ;
230- let tree_path = version_branch
242+ let tree_path = branch
231243 . as_ref ( )
232244 . map_or ( String :: new ( ) , |b| format ! ( "/tree/{b}" ) ) ;
233245 println ! ( "Kernel uploaded: https://hf.co/{type_prefix}{repo_id}{tree_path}" ) ;
@@ -781,4 +793,76 @@ mod tests {
781793 let result = discover_build_file ( temp_dir. path ( ) , "CARD.md" ) ;
782794 assert ! ( result. is_err( ) ) ;
783795 }
796+
797+ #[ test]
798+ fn test_branch_from_build_toml ( ) {
799+ let temp_dir = tempfile:: tempdir ( ) . unwrap ( ) ;
800+ let kernel_dir = temp_dir. path ( ) ;
801+
802+ fs:: write (
803+ kernel_dir. join ( "build.toml" ) ,
804+ r#"[general]
805+ name = "test-kernel"
806+ backends = ["cuda"]
807+
808+ [general.hub]
809+ repo-id = "test/kernel"
810+ branch = "custom-branch"
811+ "# ,
812+ )
813+ . unwrap ( ) ;
814+
815+ let build_dir = kernel_dir. join ( "build" ) ;
816+ let variant = build_dir. join ( "torch-cuda" ) ;
817+ fs:: create_dir_all ( & variant) . unwrap ( ) ;
818+ fs:: write ( variant. join ( "metadata.json" ) , METADATA_V3 ) . unwrap ( ) ;
819+ fs:: write ( variant. join ( "kernel.so" ) , "binary" ) . unwrap ( ) ;
820+
821+ let variants = vec ! [ variant. clone( ) ] ;
822+ let ( repo_id, branch) = get_repo_and_branch ( kernel_dir, None , None , & variants) . unwrap ( ) ;
823+
824+ assert_eq ! ( repo_id, "test/kernel" ) ;
825+ assert_eq ! ( branch, Some ( "custom-branch" . to_owned( ) ) ) ;
826+
827+ // Verify commit ops are generated - these would be uploaded to the branch above.
828+ let mut operations = vec ! [ ] ;
829+ collect_build_commit_ops ( & build_dir, & variants, & [ ] , false , & mut operations) . unwrap ( ) ;
830+ assert ! ( !operations. is_empty( ) ) ;
831+ }
832+
833+ #[ test]
834+ fn test_args_take_priority_over_files ( ) {
835+ let temp_dir = tempfile:: tempdir ( ) . unwrap ( ) ;
836+ let kernel_dir = temp_dir. path ( ) ;
837+
838+ fs:: write (
839+ kernel_dir. join ( "build.toml" ) ,
840+ r#"[general]
841+ name = "test-kernel"
842+ backends = ["cuda"]
843+
844+ [general.hub]
845+ repo-id = "build-toml/kernel"
846+ branch = "build-toml-branch"
847+ "# ,
848+ )
849+ . unwrap ( ) ;
850+
851+ let build_dir = kernel_dir. join ( "build" ) ;
852+ let variant = build_dir. join ( "torch-cuda" ) ;
853+ fs:: create_dir_all ( & variant) . unwrap ( ) ;
854+ fs:: write ( variant. join ( "metadata.json" ) , METADATA_V3 ) . unwrap ( ) ;
855+
856+ let variants = vec ! [ variant] ;
857+ let ( repo_id, branch) = get_repo_and_branch (
858+ kernel_dir,
859+ Some ( "args/kernel" . to_owned ( ) ) ,
860+ Some ( "args-branch" . to_owned ( ) ) ,
861+ & variants,
862+ )
863+ . unwrap ( ) ;
864+
865+ assert_eq ! ( repo_id, "args/kernel" ) ;
866+ assert_eq ! ( branch, Some ( "args-branch" . to_owned( ) ) ) ;
867+ }
784868}
0 commit comments