@@ -348,6 +348,53 @@ extern "C" {
348348 // Set a callback to be called for each resulting node during graph compute
349349 GGML_API void ggml_backend_sched_set_eval_callback (ggml_backend_sched_t sched , ggml_backend_sched_eval_callback callback , void * user_data );
350350
//
// Meta backend
//

// maximum number of "simple" devices that can be combined into a single meta device
#define GGML_BACKEND_META_MAX_DEVICES 16
356+
// describes how a tensor's data is distributed across the devices of a meta backend
enum ggml_backend_meta_split_axis {
    // split along one of the tensor's dimensions:
    GGML_BACKEND_SPLIT_AXIS_0 = 0,
    GGML_BACKEND_SPLIT_AXIS_1 = 1,
    GGML_BACKEND_SPLIT_AXIS_2 = 2,
    GGML_BACKEND_SPLIT_AXIS_3 = 3,

    GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values present on all backends
    GGML_BACKEND_SPLIT_AXIS_PARTIAL  = 11, // each backend holds a partial sum

    // for internal bookkeeping only:
    GGML_BACKEND_SPLIT_AXIS_NONE    = 98,
    GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
};
// human-readable name of a split axis value (e.g. for debugging/logging)
GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
372+
373+ struct ggml_backend_meta_split_state {
374+ enum ggml_backend_meta_split_axis axis ;
375+
376+ // for tensors with axis >= 0 && axis < GGML_MAX_DIMS:
377+ // - each device has a slice of the tensor along the split axis
378+ // - most tensors have n_segments == 1 and a contiguous slice of the tensor data
379+ // - some tensors have an inhomogenenous data layout along the split axis,
380+ // those tensors are divided into segments which are each individually split across devices
381+ // - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis,
382+ // the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1],
383+ // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments
384+ // that each need to be split individually across devices so that each device gets a slice of Q, K, and V
385+ int64_t ne [16 * GGML_BACKEND_META_MAX_DEVICES ];
386+ uint32_t n_segments ;
387+ };
388+
// callback that assigns a split state to each statically allocated tensor;
// split states for computed tensors will then be assigned to be compatible
typedef struct ggml_backend_meta_split_state (*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
391+
392+ // create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
393+ // TODO: this looks a bit strange - a backend API creates a device. I think we should try
394+ // express this as a backend registry functionality instead
395+ GGML_API ggml_backend_dev_t ggml_backend_meta_device (
396+ ggml_backend_dev_t * devs , size_t n_devs , ggml_backend_meta_get_split_state_t get_split_state , void * get_split_state_ud );
397+
351398 //
352399 // Utils
353400 //
0 commit comments