@@ -4,13 +4,14 @@ use std::io::Write;
44use std:: path:: PathBuf ;
55
66use eyre:: { bail, Context , Result } ;
7- use itertools:: Itertools ;
87use minijinja:: { context, Environment } ;
98
10- use super :: common:: write_pyproject_toml;
11- use super :: kernel_ops_identifier;
12- use crate :: config:: { Backend , Build , Dependency , Kernel , Torch } ;
9+ use crate :: config:: { Backend , Build , Dependency , Torch } ;
10+ use crate :: torch:: common:: prefix_and_join_includes;
1311use crate :: torch:: common:: write_metadata;
12+ use crate :: torch:: common:: write_pyproject_toml;
13+ use crate :: torch:: kernel:: render_kernel_components;
14+ use crate :: torch:: kernel_ops_identifier;
1415use crate :: version:: Version ;
1516use crate :: FileSet ;
1617
@@ -181,13 +182,7 @@ fn write_cmake(
181182
182183 render_binding ( env, torch, name, cmake_writer) ?;
183184
184- for ( kernel_name, kernel) in build
185- . kernels
186- . iter ( )
187- . filter ( |( _, kernel) | kernel. backend ( ) == backend)
188- {
189- render_kernel ( env, kernel_name, kernel, cmake_writer) ?;
190- }
185+ render_kernel_components ( env, build, cmake_writer) ?;
191186
192187 render_extension ( env, name, ops_name, cmake_writer) ?;
193188
@@ -312,71 +307,6 @@ fn render_deps(
312307 Ok ( ( ) )
313308}
314309
315- pub fn render_kernel (
316- env : & Environment ,
317- kernel_name : & str ,
318- kernel : & Kernel ,
319- write : & mut impl Write ,
320- ) -> Result < ( ) > {
321- // Easier to do in Rust than Jinja.
322- let sources = kernel
323- . src ( )
324- . iter ( )
325- . map ( |src| format ! ( "\" {src}\" " ) )
326- . collect_vec ( )
327- . join ( "\n " ) ;
328-
329- let ( cuda_capabilities, rocm_archs, cuda_flags, hip_flags, cuda_minver) = match kernel {
330- Kernel :: Cuda {
331- cuda_capabilities,
332- cuda_flags,
333- cuda_minver,
334- ..
335- } => (
336- cuda_capabilities. as_deref ( ) ,
337- None ,
338- cuda_flags. as_deref ( ) ,
339- None ,
340- cuda_minver. as_ref ( ) ,
341- ) ,
342- Kernel :: Rocm {
343- rocm_archs,
344- hip_flags,
345- ..
346- } => (
347- None ,
348- rocm_archs. as_deref ( ) ,
349- None ,
350- hip_flags. as_deref ( ) ,
351- None ,
352- ) ,
353- _ => unreachable ! ( "Unsupported kernel type for CUDA rendering" ) ,
354- } ;
355-
356- env. get_template ( "cuda/kernel.cmake" )
357- . wrap_err ( "Cannot get kernel template" ) ?
358- . render_to_write (
359- context ! {
360- cuda_capabilities => cuda_capabilities,
361- cuda_flags => cuda_flags. map( |flags| flags. join( ";" ) ) ,
362- cuda_minver => cuda_minver. map( ToString :: to_string) ,
363- cxx_flags => kernel. cxx_flags( ) . map( |flags| flags. join( ";" ) ) ,
364- rocm_archs => rocm_archs,
365- hip_flags => hip_flags. map( |flags| flags. join( ";" ) ) ,
366- includes => kernel. include( ) . map( prefix_and_join_includes) ,
367- kernel_name => kernel_name,
368- supports_hipify => matches!( kernel, Kernel :: Rocm { .. } ) ,
369- sources => sources,
370- } ,
371- & mut * write,
372- )
373- . wrap_err ( "Cannot render kernel template" ) ?;
374-
375- write. write_all ( b"\n " ) ?;
376-
377- Ok ( ( ) )
378- }
379-
380310pub fn render_extension (
381311 env : & Environment ,
382312 name : & str ,
@@ -428,15 +358,3 @@ pub fn render_preamble(
428358
429359 Ok ( ( ) )
430360}
431-
432- fn prefix_and_join_includes < S > ( includes : impl AsRef < [ S ] > ) -> String
433- where
434- S : AsRef < str > ,
435- {
436- includes
437- . as_ref ( )
438- . iter ( )
439- . map ( |include| format ! ( "${{CMAKE_SOURCE_DIR}}/{}" , include. as_ref( ) ) )
440- . collect_vec ( )
441- . join ( ";" )
442- }
0 commit comments