forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: EmbeddingBackwardKernel.cuh
More file actions
36 lines (29 loc) · 883 Bytes
/
EmbeddingBackwardKernel.cuh
File metadata and controls
36 lines (29 loc) · 883 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/TensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCTensorMathReduce.cuh>
#include <THC/THCTensorSort.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCAtomics.cuh>
#include <thrust/execution_policy.h>
#include <thrust/unique.h>
#include <thrust/device_vector.h>
#pragma once
namespace at {
namespace native {

// Forward declaration of the CUDA embedding-backward routine. Only the
// signature lives in this header; the definition is in a separate
// translation unit that is not visible here.
//
// Parameters (semantics below are inferred from names and must be confirmed
// against the definition):
//   grad               - incoming gradient tensor (presumably dL/d(output)).
//   orig_indices       - presumably the original positions of the indices
//                        before sorting — TODO confirm.
//   sorted_indices     - presumably the lookup indices in sorted order.
//   count              - presumably per-index occurrence counts, used when
//                        scale_grad_by_freq is set — TODO confirm.
//   num_weights        - number of rows in the embedding table (assumed).
//   padding_idx        - defaults to -1; presumably -1 means "no padding
//                        index" — confirm.
//   scale_grad_by_freq - defaults to false.
//   mode_mean          - defaults to false; presumably selects mean-mode
//                        reduction for embedding bags — confirm.
//   offset2bag, bag_size, per_sample_weights
//                      - default to a default-constructed Tensor(),
//                        presumably meaning "not supplied" — confirm.
// Returns: a Tensor, presumably the gradient w.r.t. the embedding weight.
Tensor embedding_backward_cuda_kernel(
    const Tensor& grad,
    const Tensor& orig_indices,
    const Tensor& sorted_indices,
    const Tensor& count,
    int64_t num_weights,
    int padding_idx = -1,
    bool scale_grad_by_freq = false,
    bool mode_mean = false,
    const Tensor& offset2bag = Tensor(),
    const Tensor& bag_size = Tensor(),
    const Tensor& per_sample_weights = Tensor());

} // namespace native
} // namespace at