Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions build2cmake/src/templates/cpu/torch-binding.cmake

This file was deleted.

9 changes: 0 additions & 9 deletions build2cmake/src/templates/cpu/torch-extension.cmake

This file was deleted.

16 changes: 0 additions & 16 deletions build2cmake/src/templates/cuda/torch-binding.cmake

This file was deleted.

25 changes: 0 additions & 25 deletions build2cmake/src/templates/cuda/torch-extension.cmake

This file was deleted.

17 changes: 0 additions & 17 deletions build2cmake/src/templates/metal/torch-extension.cmake

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
#list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
if(GPU_LANG STREQUAL "CUDA")
get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})
endif()

set(TORCH_{{name}}_SRC
{{ src|join(' ') }}
Expand Down
39 changes: 39 additions & 0 deletions build2cmake/src/templates/torch-extension.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Include Metal shader compilation utilities if needed
if(GPU_LANG STREQUAL "METAL")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/compile-metal.cmake)
endif()

# Define the extension target with unified parameters
define_gpu_extension_target(
{{ ops_name }}
DESTINATION {{ ops_name }}
LANGUAGE ${GPU_LANG}
SOURCES ${SRC}
COMPILE_FLAGS ${GPU_FLAGS}
ARCHITECTURES ${GPU_ARCHES}
USE_SABI 3
WITH_SOABI)

if(NOT (MSVC OR GPU_LANG STREQUAL "SYCL"))
target_link_options({{ ops_name }} PRIVATE -static-libstdc++)
endif()

if(GPU_LANG STREQUAL "SYCL")
target_link_options({{ ops_name }} PRIVATE ${sycl_link_flags})
target_link_libraries({{ ops_name }} PRIVATE dnnl)
endif()

# Compile Metal shaders if any were found
if(GPU_LANG STREQUAL "METAL")
if(ALL_METAL_SOURCES)
compile_metal_shaders({{ ops_name }} "${ALL_METAL_SOURCES}" "${METAL_INCLUDE_DIRS}")
endif()
endif()

{% if platform == 'windows' %}
# Add kernels_install target for huggingface/kernels library layout
add_kernels_install_target({{ ops_name }} "{{ name }}" "${BUILD_VARIANT_NAME}")

# Add local_install target for local development with get_local_kernel()
add_local_install_target({{ ops_name }} "{{ name }}" "${BUILD_VARIANT_NAME}")
{% endif %}
3 changes: 2 additions & 1 deletion build2cmake/src/templates/xpu/preamble.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ add_compile_definitions(USE_XPU)
# Set SYCL-specific flags
# Set comprehensive SYCL compilation and linking flags
set(sycl_link_flags "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required';")
set(sycl_flags "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")
set(GPU_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")
set(GPU_ARCHES "")
message(STATUS "Configuring for Intel XPU backend using SYCL")

{% if platform == 'windows' %}
Expand Down
13 changes: 0 additions & 13 deletions build2cmake/src/templates/xpu/torch-binding.cmake

This file was deleted.

23 changes: 0 additions & 23 deletions build2cmake/src/templates/xpu/torch-extension.cmake

This file was deleted.

50 changes: 49 additions & 1 deletion build2cmake/src/torch/common.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::io::Write;

use eyre::{Context, Result};
use itertools::Itertools;
use minijinja::{context, Environment};

use crate::config::{Backend, General};
use crate::config::{Backend, General, Torch};
use crate::metadata::Metadata;
use crate::FileSet;

Expand Down Expand Up @@ -64,3 +66,49 @@ where
.collect_vec()
.join(";")
}

pub fn render_binding(
env: &Environment,
torch: &Torch,
name: &str,
write: &mut impl Write,
) -> Result<()> {
env.get_template("torch-binding.cmake")
.wrap_err("Cannot get Torch binding template")?
.render_to_write(
context! {
includes => torch.include.as_ref().map(prefix_and_join_includes),
name => name,
src => torch.src
},
&mut *write,
)
.wrap_err("Cannot render Torch binding template")?;

write.write_all(b"\n")?;

Ok(())
}

pub fn render_extension(
env: &Environment,
name: &str,
ops_name: &str,
write: &mut impl Write,
) -> Result<()> {
env.get_template("torch-extension.cmake")
.wrap_err("Cannot get Torch extension template")?
.render_to_write(
context! {
name => name,
ops_name => ops_name,
platform => std::env::consts::OS,
},
&mut *write,
)
.wrap_err("Cannot render Torch extension template")?;

write.write_all(b"\n")?;

Ok(())
}
63 changes: 3 additions & 60 deletions build2cmake/src/torch/cpu.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::{io::Write, path::PathBuf};

use eyre::{bail, Context, Result};
use itertools::Itertools;
use minijinja::{context, Environment};

use crate::config::{Backend, Build, Torch};
use crate::fileset::FileSet;
use crate::torch::common::write_metadata;
use crate::torch::common::write_pyproject_toml;
use crate::torch::common::{
render_binding, render_extension, write_metadata, write_pyproject_toml,
};
use crate::torch::kernel::render_kernel_components;
use crate::torch::kernel_ops_identifier;
use crate::version::Version;
Expand Down Expand Up @@ -103,51 +103,6 @@ fn write_cmake(
Ok(())
}

fn render_binding(
env: &Environment,
torch: &Torch,
name: &str,
write: &mut impl Write,
) -> Result<()> {
env.get_template("cpu/torch-binding.cmake")
.wrap_err("Cannot get Torch binding template")?
.render_to_write(
context! {
includes => torch.include.as_ref().map(prefix_and_join_includes),
name => name,
src => torch.src
},
&mut *write,
)
.wrap_err("Cannot render Torch binding template")?;

write.write_all(b"\n")?;

Ok(())
}

pub fn render_extension(
env: &Environment,
name: &str,
ops_name: &str,
write: &mut impl Write,
) -> Result<()> {
env.get_template("cpu/torch-extension.cmake")
.wrap_err("Cannot get Torch extension template")?
.render_to_write(
context! {
name => name,
ops_name => ops_name,
},
&mut *write,
)
.wrap_err("Cannot render Torch extension template")?;

write.write_all(b"\n")?;

Ok(())
}

fn render_preamble(
env: &Environment,
name: &str,
Expand Down Expand Up @@ -234,15 +189,3 @@ fn write_torch_registration_macros(file_set: &mut FileSet) -> Result<()> {

Ok(())
}

fn prefix_and_join_includes<S>(includes: impl AsRef<[S]>) -> String
where
S: AsRef<str>,
{
includes
.as_ref()
.iter()
.map(|include| format!("${{CMAKE_SOURCE_DIR}}/{}", include.as_ref()))
.collect_vec()
.join(";")
}
Loading
Loading