Skip to content

Commit a1afaed

Browse files
committed
feat(kmeans): add gpu impl.
1 parent ec6529b commit a1afaed

18 files changed

Lines changed: 1410 additions & 220 deletions

bezier_sycl.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
//go:generate clang++ -fsycl -fsycl-device-only -fno-sycl-instrument-device-code -fsycl-targets=spirv64 -Xclang -emit-llvm-bc internal/build/bezier_sycl.cpp -o internal/build/device_bezier_kern.bc
1111
//go:generate sycl-post-link -symbols -split=auto -emit-param-info -properties -o internal/build/device_bezier_kern.table internal/build/device_bezier_kern.bc
1212
//go:generate llvm-spirv --sycl-opt -o internal/build/bezier_sycl.spv internal/build/device_bezier_kern_0.bc
13-
//go:generate clang++ -target spirv64-unknown-unknown -S -emit-llvm -x ir internal/build/device_bezier_kern_0.bc -o main.ll
13+
//go:generate clang++ -target spirv64-unknown-unknown -S -emit-llvm -x ir internal/build/device_bezier_kern_0.bc -o internal/build/bezier_sycl.ll
1414
//go:generate llvm-spirv -to-text internal/build/bezier_sycl.spv -o internal/build/bezier_sycl.spt
1515

1616
//go:embed internal/build/bezier_sycl.spv
@@ -27,7 +27,7 @@ func init() {
2727
}
2828

2929
var err error
30-
bezierModel, err = gpu.CreateModuleAndCheckKernels(bezierspv)
30+
bezierModel, err = gpu.ModuleCreateAndCheckKernels(bezierspv)
3131
if err != nil {
3232
return
3333
}
@@ -36,13 +36,13 @@ func init() {
3636
}
3737

3838
// quadraticBezeirGPU runs the "__sycl_kernel_quadratic" kernel from
// bezierModel, passing p followed by x0, y0, x1, y1, x2, y2 and ds as the
// call arguments, and returns the error reported by the gpu package.
// This hunk shows the dispatch helper being renamed from gpu.Exec1D to
// gpu.Exec1D1Buf; the kernel name and argument order are unchanged.
func quadraticBezeirGPU(x0, y0, x1, y1, x2, y2, ds float64, p []Point) error {
39-
return gpu.Exec1D(bezierModel, "__sycl_kernel_quadratic", p,
39+
return gpu.Exec1D1Buf(bezierModel, "__sycl_kernel_quadratic", p,
4040
x0, y0, x1, y1, x2, y2, ds,
4141
)
4242
}
4343

4444
// cubicBezeirGPU runs the "__sycl_kernel_cubic" kernel from bezierModel,
// passing p followed by x0, y0, x1, y1, x2, y2, x3, y3 and ds as the call
// arguments, and returns the error reported by the gpu package.
// This hunk shows the dispatch helper being renamed from gpu.Exec1D to
// gpu.Exec1D1Buf; the kernel name and argument order are unchanged.
func cubicBezeirGPU(x0, y0, x1, y1, x2, y2, x3, y3, ds float64, p []Point) error {
45-
return gpu.Exec1D(bezierModel, "__sycl_kernel_cubic", p,
45+
return gpu.Exec1D1Buf(bezierModel, "__sycl_kernel_cubic", p,
4646
x0, y0, x1, y1, x2, y2, x3, y3, ds,
4747
)
4848
}

color.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ var (
2727
// TakeThemeColorsKMeans extracts the k dominant colors from an image using k-means.
2828
//
2929
// TakeThemeColorsKMeans 使用 k-means 算法从图像中提取 k 个主色。
30-
func TakeThemeColorsKMeans(img image.Image, k int) []color.RGBA {
30+
func TakeThemeColorsKMeans(img image.Image, k uint16) []color.RGBA {
3131
ki := newKMeansImage(img, k) // 初始化k个聚类中心
32+
defer ki.destroy()
3233
for {
3334
ki.assign()
3435
ki.update()

color_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func TestClustersEqual_DiffersOnlyInAlpha(t *testing.T) {
216216
func TestTakecolor_ReturnsKColors(t *testing.T) {
217217
img := solidImage(10, 10, color.RGBA{128, 64, 32, 255})
218218
for k := 1; k <= 5; k++ {
219-
result := TakeThemeColorsKMeans(img, k)
219+
result := TakeThemeColorsKMeans(img, uint16(k))
220220
if len(result) != k {
221221
t.Errorf("takecolor with k=%d returned %d colors, want %d", k, len(result), k)
222222
}

context.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1365,6 +1365,6 @@ func (dc *Context) String() string {
13651365
// TakeThemeColorsKMeans extracts the k dominant colors from the drawn image using k-means.
13661366
//
13671367
// TakeThemeColorsKMeans 使用 k-means 算法从已绘制图像中提取 k 个主色。
1368-
func (dc *Context) TakeThemeColorsKMeans(k int) []color.RGBA {
1368+
func (dc *Context) TakeThemeColorsKMeans(k uint16) []color.RGBA {
13691369
return TakeThemeColorsKMeans(dc.im, k)
13701370
}

internal/build/bezier_sycl.ll

Lines changed: 50 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
; ModuleID = 'device_bezier_kern_0.bc'
2-
source_filename = "bezier_sycl.cpp"
1+
; ModuleID = 'internal/build/device_bezier_kern_0.bc'
2+
source_filename = "internal/build/bezier_sycl.cpp"
33
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
44
target triple = "spirv64-unknown-unknown"
55

@@ -8,8 +8,8 @@ target triple = "spirv64-unknown-unknown"
88
@__spirv_BuiltInGlobalInvocationId = external local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
99

1010
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write, inaccessiblemem: write)
11-
define spir_kernel void @__sycl_kernel_quadratic(double noundef %0, double noundef %1, double noundef %2, double noundef %3, double noundef %4, double noundef %5, double noundef %6, ptr addrspace(1) noundef writeonly align 8 captures(none) %7) local_unnamed_addr #0 !kernel_arg_buffer_location !7 !sycl_used_aspects !8 !sycl_fixed_targets !10 !sycl_kernel_omit_args !11 {
12-
%9 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32, !noalias !12
11+
define spir_kernel void @__sycl_kernel_quadratic(double noundef %0, double noundef %1, double noundef %2, double noundef %3, double noundef %4, double noundef %5, double noundef %6, ptr addrspace(1) noundef writeonly align 8 captures(none) %7) local_unnamed_addr #0 !kernel_arg_buffer_location !9 !sycl_used_aspects !10 !sycl_fixed_targets !12 !sycl_kernel_omit_args !13 {
12+
%9 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32, !noalias !14
1313
%10 = icmp ult i64 %9, 2147483648
1414
tail call void @llvm.assume(i1 %10)
1515
%11 = uitofp nneg i64 %9 to double
@@ -23,31 +23,31 @@ define spir_kernel void @__sycl_kernel_quadratic(double noundef %0, double nound
2323
%19 = tail call double @llvm.fmuladd.f64(double %14, double %0, double %18)
2424
%20 = tail call double @llvm.fmuladd.f64(double %17, double %4, double %19)
2525
%21 = getelementptr inbounds %struct.point, ptr addrspace(1) %7, i64 %9
26-
store double %20, ptr addrspace(1) %21, align 8, !tbaa !19
26+
store double %20, ptr addrspace(1) %21, align 8
2727
%22 = fmul double %16, %3
2828
%23 = tail call double @llvm.fmuladd.f64(double %14, double %1, double %22)
2929
%24 = tail call double @llvm.fmuladd.f64(double %17, double %5, double %23)
3030
%25 = getelementptr inbounds i8, ptr addrspace(1) %21, i64 8
31-
store double %24, ptr addrspace(1) %25, align 8, !tbaa !24
31+
store double %24, ptr addrspace(1) %25, align 8
3232
ret void
3333
}
3434

3535
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
36-
declare !sycl_used_aspects !8 double @llvm.fmuladd.f64(double, double, double) #1
36+
declare !sycl_used_aspects !10 double @llvm.fmuladd.f64(double, double, double) #1
3737

3838
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write)
3939
declare void @llvm.assume(i1 noundef) #2
4040

4141
; Function Attrs: mustprogress norecurse nounwind
42-
define spir_kernel void @__sycl_kernel_cubic(double noundef %0, double noundef %1, double noundef %2, double noundef %3, double noundef %4, double noundef %5, double noundef %6, double noundef %7, double noundef %8, ptr addrspace(1) noundef align 8 %9) local_unnamed_addr #3 !kernel_arg_buffer_location !25 !sycl_used_aspects !8 !sycl_fixed_targets !10 !sycl_kernel_omit_args !26 {
42+
define spir_kernel void @__sycl_kernel_cubic(double noundef %0, double noundef %1, double noundef %2, double noundef %3, double noundef %4, double noundef %5, double noundef %6, double noundef %7, double noundef %8, ptr addrspace(1) noundef align 8 %9) local_unnamed_addr #3 !kernel_arg_buffer_location !21 !sycl_used_aspects !10 !sycl_fixed_targets !12 !sycl_kernel_omit_args !22 {
4343
%11 = addrspacecast ptr addrspace(1) %9 to ptr addrspace(4)
4444
tail call spir_func void @cubic(double noundef %0, double noundef %1, double noundef %2, double noundef %3, double noundef %4, double noundef %5, double noundef %6, double noundef %7, double noundef %8, ptr addrspace(4) noundef %11) #5
4545
ret void
4646
}
4747

4848
; Function Attrs: mustprogress norecurse nounwind
49-
define internal spir_func void @cubic(double noundef %0, double noundef %1, double noundef %2, double noundef %3, double noundef %4, double noundef %5, double noundef %6, double noundef %7, double noundef %8, ptr addrspace(4) noundef %9) local_unnamed_addr #4 !sycl_used_aspects !8 {
50-
%11 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32, !noalias !27
49+
define internal spir_func void @cubic(double noundef %0, double noundef %1, double noundef %2, double noundef %3, double noundef %4, double noundef %5, double noundef %6, double noundef %7, double noundef %8, ptr addrspace(4) noundef %9) local_unnamed_addr #4 !sycl_used_aspects !10 {
50+
%11 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32, !noalias !23
5151
%12 = icmp ult i64 %11, 2147483648
5252
tail call void @llvm.assume(i1 %12)
5353
%13 = uitofp nneg i64 %11 to double
@@ -67,60 +67,57 @@ define internal spir_func void @cubic(double noundef %0, double noundef %1, doub
6767
%27 = tail call double @llvm.fmuladd.f64(double %22, double %4, double %26)
6868
%28 = tail call double @llvm.fmuladd.f64(double %24, double %6, double %27)
6969
%29 = getelementptr inbounds nuw %struct.point, ptr addrspace(4) %9, i64 %11
70-
store double %28, ptr addrspace(4) %29, align 8, !tbaa !19
70+
store double %28, ptr addrspace(4) %29, align 8
7171
%30 = fmul double %20, %3
7272
%31 = tail call double @llvm.fmuladd.f64(double %17, double %1, double %30)
7373
%32 = tail call double @llvm.fmuladd.f64(double %22, double %5, double %31)
7474
%33 = tail call double @llvm.fmuladd.f64(double %24, double %7, double %32)
7575
%34 = getelementptr inbounds nuw i8, ptr addrspace(4) %29, i64 8
76-
store double %33, ptr addrspace(4) %34, align 8, !tbaa !24
76+
store double %33, ptr addrspace(4) %34, align 8
7777
ret void
7878
}
7979

80-
attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write, inaccessiblemem: write) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-entry-point" "sycl-module-id"="bezier_sycl.cpp" "sycl-nd-range-kernel"="1" "sycl-optlevel"="2" "uniform-work-group-size"="true" }
80+
attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write, inaccessiblemem: write) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-entry-point" "sycl-module-id"="internal/build/bezier_sycl.cpp" "sycl-nd-range-kernel"="1" "sycl-optlevel"="2" "uniform-work-group-size"="true" }
8181
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
8282
attributes #2 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) }
83-
attributes #3 = { mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-entry-point" "sycl-module-id"="bezier_sycl.cpp" "sycl-nd-range-kernel"="1" "sycl-optlevel"="2" "uniform-work-group-size"="true" }
83+
attributes #3 = { mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-entry-point" "sycl-module-id"="internal/build/bezier_sycl.cpp" "sycl-nd-range-kernel"="1" "sycl-optlevel"="2" "uniform-work-group-size"="true" }
8484
attributes #4 = { mustprogress norecurse nounwind "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-nd-range-kernel"="1" "sycl-optlevel"="2" }
8585
attributes #5 = { nounwind }
8686

87-
!llvm.module.flags = !{!0, !1, !2}
88-
!opencl.spir.version = !{!3}
89-
!spirv.Source = !{!4}
90-
!llvm.ident = !{!5}
91-
!sycl-esimd-split-status = !{!6}
87+
!llvm.linker.options = !{!0, !1}
88+
!llvm.module.flags = !{!2, !3, !4}
89+
!opencl.spir.version = !{!5}
90+
!spirv.Source = !{!6}
91+
!llvm.ident = !{!7}
92+
!sycl-esimd-split-status = !{!8}
9293

93-
!0 = !{i32 1, !"wchar_size", i32 4}
94-
!1 = !{i32 1, !"sycl-device", i32 1}
95-
!2 = !{i32 7, !"frame-pointer", i32 2}
96-
!3 = !{i32 1, i32 2}
97-
!4 = !{i32 4, i32 100000}
98-
!5 = !{!"clang version 21.0.0git (https://github.com/intel/llvm d5f649b706f63b5c74e1929bc95db8de91085560)"}
99-
!6 = !{i8 0}
100-
!7 = !{i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1}
101-
!8 = !{!9}
102-
!9 = !{!"fp64", i32 6}
103-
!10 = !{}
104-
!11 = !{i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false}
105-
!12 = !{!13, !15, !17}
106-
!13 = distinct !{!13, !14, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv: argument 0"}
107-
!14 = distinct !{!14, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv"}
108-
!15 = distinct !{!15, !16, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v: argument 0"}
109-
!16 = distinct !{!16, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v"}
110-
!17 = distinct !{!17, !18, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv: argument 0"}
111-
!18 = distinct !{!18, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv"}
112-
!19 = !{!20, !21, i64 0}
113-
!20 = !{!"_ZTS5point", !21, i64 0, !21, i64 8}
114-
!21 = !{!"double", !22, i64 0}
115-
!22 = !{!"omnipotent char", !23, i64 0}
116-
!23 = !{!"Simple C++ TBAA"}
117-
!24 = !{!20, !21, i64 8}
118-
!25 = !{i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1}
119-
!26 = !{i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false}
120-
!27 = !{!28, !30, !32}
121-
!28 = distinct !{!28, !29, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv: argument 0"}
122-
!29 = distinct !{!29, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv"}
123-
!30 = distinct !{!30, !31, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v: argument 0"}
124-
!31 = distinct !{!31, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v"}
125-
!32 = distinct !{!32, !33, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv: argument 0"}
126-
!33 = distinct !{!33, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv"}
94+
!0 = !{!"-llibcpmt"}
95+
!1 = !{!"/alternatename:_Avx2WmemEnabled=_Avx2WmemEnabledWeakValue"}
96+
!2 = !{i32 1, !"wchar_size", i32 2}
97+
!3 = !{i32 1, !"sycl-device", i32 1}
98+
!4 = !{i32 7, !"frame-pointer", i32 2}
99+
!5 = !{i32 1, i32 2}
100+
!6 = !{i32 4, i32 100000}
101+
!7 = !{!"clang version 21.0.0git (https://github.com/intel/llvm d5f649b706f63b5c74e1929bc95db8de91085560)"}
102+
!8 = !{i8 0}
103+
!9 = !{i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1}
104+
!10 = !{!11}
105+
!11 = !{!"fp64", i32 6}
106+
!12 = !{}
107+
!13 = !{i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false}
108+
!14 = !{!15, !17, !19}
109+
!15 = distinct !{!15, !16, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv: argument 0"}
110+
!16 = distinct !{!16, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv"}
111+
!17 = distinct !{!17, !18, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v: argument 0"}
112+
!18 = distinct !{!18, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v"}
113+
!19 = distinct !{!19, !20, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv: argument 0"}
114+
!20 = distinct !{!20, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv"}
115+
!21 = !{i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1, i32 -1}
116+
!22 = !{i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false}
117+
!23 = !{!24, !26, !28}
118+
!24 = distinct !{!24, !25, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv: argument 0"}
119+
!25 = distinct !{!25, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv"}
120+
!26 = distinct !{!26, !27, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v: argument 0"}
121+
!27 = distinct !{!27, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v"}
122+
!28 = distinct !{!28, !29, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv: argument 0"}
123+
!29 = distinct !{!29, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv"}

internal/build/kmeans_ocl.cl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
 * assign_first_iter - first k-means assignment pass.
 *
 * Each work-item handles one pixel (x, y) of sampleResult:
 *   1. fetches the pixel from inputImg (directly when both images have the
 *      same size, otherwise through the sampler `smp`),
 *   2. stores that pixel into sampleResult,
 *   3. writes the index of the nearest cluster into clusterAssignments.
 *
 * Parameters:
 *   inputImg           source image.
 *   smp                sampler used only when the sizes differ.
 *   clusters           cluster centers, read from row 0; the image width
 *                      is taken as the cluster count k.
 *   clusterAssignments output buffer indexed as x + y * width(sampleResult).
 *   sampleResult       destination image; its size bounds the work domain.
 *
 * The stored index stays USHRT_MAX when no cluster compares closer than
 * FLT_MAX (for example when k is 0).
 */
kernel void assign_first_iter(
    read_only image2d_t inputImg,
    sampler_t smp,
    read_only image2d_t clusters,
    __global ushort* clusterAssignments,
    write_only image2d_t sampleResult)
{
    /* get_global_id returns size_t; narrow explicitly instead of implicitly. */
    uint x = (uint)get_global_id(0);
    uint y = (uint)get_global_id(1);

    uint inpW = get_image_width(inputImg);
    uint inpH = get_image_height(inputImg);

    uint dstW = get_image_width(sampleResult);
    uint dstH = get_image_height(sampleResult);

    /* Work-items outside the destination image have nothing to do. */
    if (x >= dstW || y >= dstH) {
        return;
    }

    uint k = get_image_width(clusters);

    float4 pixel;
    if (inpW == dstW && inpH == dstH) {
        /* Same size: sampler-less read at integer coordinates. */
        pixel = read_imagef(inputImg, (int2)(x, y));
    } else {
        /*
         * Different size: address the source with coordinates scaled to
         * [0, 1). Assumes smp is configured for normalized coordinates --
         * TODO confirm against the host code.
         * NOTE(review): no +0.5 texel-center offset is applied; verify this
         * matches the intended resampling.
         */
        float2 normCoord = (float2)(
            (float)x / (float)dstW,
            (float)y / (float)dstH
        );
        pixel = read_imagef(inputImg, smp, normCoord);
    }
    write_imagef(sampleResult, (int2)(x, y), pixel);

    float minDistance = FLT_MAX;
    ushort assign = USHRT_MAX;
    /* Unsigned index so the comparison with k is not signed/unsigned. */
    for (uint i = 0; i < k; i++) {
        float4 cluster = read_imagef(clusters, (int2)(i, 0));
        float4 diff = pixel - cluster;
        /* The fourth (w) component does not contribute to the distance. */
        diff.w = 0;
        float d = dot(diff, diff);
        if (d < minDistance) {
            minDistance = d;
            assign = (ushort)i;
        }
    }
    clusterAssignments[x + y * dstW] = assign;
}
49+
50+
/*
 * assign_remaining_iter - k-means assignment pass over an already sampled
 * image.
 *
 * Each work-item handles one pixel (x, y) of sampleResult, reads it, and
 * writes the index of the nearest cluster into clusterAssignments.
 *
 * Parameters:
 *   sampleResult       image to classify; its size bounds the work domain.
 *   clusters           cluster centers, read from row 0; the image width
 *                      is taken as the cluster count k.
 *   clusterAssignments output buffer indexed as x + y * width(sampleResult).
 *
 * The stored index stays USHRT_MAX when no cluster compares closer than
 * FLT_MAX (for example when k is 0).
 */
kernel void assign_remaining_iter(
    read_only image2d_t sampleResult,
    read_only image2d_t clusters,
    __global ushort* clusterAssignments)
{
    /* get_global_id returns size_t; narrow explicitly instead of implicitly. */
    uint x = (uint)get_global_id(0);
    uint y = (uint)get_global_id(1);

    uint dstW = get_image_width(sampleResult);
    uint dstH = get_image_height(sampleResult);

    /* Work-items outside the image have nothing to do. */
    if (x >= dstW || y >= dstH) {
        return;
    }

    uint k = get_image_width(clusters);

    float4 pixel = read_imagef(sampleResult, (int2)(x, y));

    float minDistance = FLT_MAX;
    ushort assign = USHRT_MAX;
    /* Unsigned index so the comparison with k is not signed/unsigned. */
    for (uint i = 0; i < k; i++) {
        float4 cluster = read_imagef(clusters, (int2)(i, 0));
        float4 diff = pixel - cluster;
        /* The fourth (w) component does not contribute to the distance. */
        diff.w = 0;
        float d = dot(diff, diff);
        if (d < minDistance) {
            minDistance = d;
            assign = (ushort)i;
        }
    }
    clusterAssignments[x + y * dstW] = assign;
}

0 commit comments

Comments
 (0)