| | |
| | |
| | |
| | |
| | |
| |
|
| | #include <torch/extension.h> |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/**
 * Accumulate the outer product of two vectors into a row-major matrix.
 *
 * Performs out[i * B + j] += a[i] * b[j] for 0 <= i < A and 0 <= j < B.
 * The destination is only added to, never cleared, so the caller decides
 * what it contains beforehand.
 *
 * @param a   Vector of A floats (read-only).
 * @param b   Vector of B floats (read-only).
 * @param out Row-major A x B buffer that is updated in place.
 * @param A   Length of a, i.e. the number of rows of out.
 * @param B   Length of b, i.e. the number of columns of out.
 */
inline void vvt_dot(const float *a, const float *b, float *out, int A, int B) {
    // Nothing to accumulate (and no pointer to advance) for empty extents.
    if (A <= 0 || B <= 0) {
        return;
    }

    // `out` walks the destination one row at a time; advancing the pointer
    // instead of computing i * B keeps the offset out of int arithmetic.
    for (int i = 0; i < A; i++, out += B) {
        for (int j = 0; j < B; j++) {
            out[j] += a[i] * b[j];
        }
    }
}
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/**
 * Multiply a row vector by a row-major matrix: out = v * m.
 *
 * Computes out[j] = sum_i v[i] * m[i * B + j] for 0 <= j < B. The previous
 * contents of out are discarded.
 *
 * @param v   Vector of A floats (read-only).
 * @param m   Row-major A x B matrix (read-only).
 * @param out Destination vector of B floats; overwritten.
 * @param A   Length of v, i.e. the number of rows of m.
 * @param B   Length of out, i.e. the number of columns of m.
 */
inline void vm_dot(const float *v, const float *m, float *out, int A, int B) {
    // With no columns there is nothing to clear or to accumulate.
    if (B <= 0) {
        return;
    }

    // The result is built by accumulation, so start from zero.
    for (int j = 0; j < B; j++) {
        out[j] = 0;
    }

    // Row-by-row traversal reads m contiguously; each row i contributes
    // v[i] * m[i, :] to the whole output vector.
    for (int i = 0; i < A; i++, m += B) {
        for (int j = 0; j < B; j++) {
            out[j] += v[i] * m[j];
        }
    }
}
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/**
 * Multiply a row vector by the transpose of a row-major matrix:
 * out = v * m^T.
 *
 * Computes out[i] = sum_j v[j] * m[i * B + j] for 0 <= i < A, i.e. the dot
 * product of v with every row of m. The previous contents of out are
 * discarded.
 *
 * @param v   Vector of B floats (read-only).
 * @param m   Row-major A x B matrix (read-only).
 * @param out Destination vector of A floats; overwritten.
 * @param A   Length of out, i.e. the number of rows of m.
 * @param B   Length of v, i.e. the number of columns of m.
 */
inline void vmt_dot(const float *v, const float *m, float *out, int A, int B) {
    // A non-positive B means empty rows: every dot product is 0 and the row
    // pointer must not move.
    const int cols = B > 0 ? B : 0;

    for (int i = 0; i < A; i++, m += cols) {
        // Accumulate in a local so out[i] is written exactly once.
        float s = 0;
        for (int j = 0; j < cols; j++) {
            s += v[j] * m[j];
        }
        out[i] = s;
    }
}
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | void causal_dot_product( |
| | const torch::Tensor queries, |
| | const torch::Tensor keys, |
| | const torch::Tensor values, |
| | torch::Tensor product |
| | ) { |
| | |
| | int N = queries.size(0); |
| | int H = queries.size(1); |
| | int L = queries.size(2); |
| | int E = queries.size(3); |
| | int M = values.size(3); |
| |
|
| | |
| | auto qa = queries.accessor<float, 4>(); |
| | auto ka = keys.accessor<float, 4>(); |
| | auto va = values.accessor<float, 4>(); |
| | auto pa = product.accessor<float, 4>(); |
| |
|
| | #pragma omp parallel for collapse(2) |
| | for (int n=0; n<N; n++) { |
| | for (int h=0; h<H; h++) { |
| | auto kv = torch::zeros({E, M}, queries.options()); |
| | float *kvp = kv.data_ptr<float>(); |
| | for (int l=0; l<L; l++) { |
| | vvt_dot( |
| | &ka[n][h][l][0], |
| | &va[n][h][l][0], |
| | kvp, |
| | E, |
| | M |
| | ); |
| | vm_dot( |
| | &qa[n][h][l][0], |
| | kvp, |
| | &pa[n][h][l][0], |
| | E, |
| | M |
| | ); |
| | } |
| | } |
| | } |
| | } |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | void causal_dot_backward( |
| | const torch::Tensor queries, |
| | const torch::Tensor keys, |
| | const torch::Tensor values, |
| | const torch::Tensor grad_out, |
| | torch::Tensor grad_queries, |
| | torch::Tensor grad_keys, |
| | torch::Tensor grad_values |
| | ) { |
| | |
| | int N = queries.size(0); |
| | int H = queries.size(1); |
| | int L = queries.size(2); |
| | int E = queries.size(3); |
| | int M = values.size(3); |
| |
|
| | |
| | auto qa = queries.accessor<float, 4>(); |
| | auto ka = keys.accessor<float, 4>(); |
| | auto va = values.accessor<float, 4>(); |
| | auto ga = grad_out.accessor<float, 4>(); |
| | auto gqa = grad_queries.accessor<float, 4>(); |
| | auto gka = grad_keys.accessor<float, 4>(); |
| | auto gva = grad_values.accessor<float, 4>(); |
| |
|
| | #pragma omp parallel for collapse(2) |
| | for (int n=0; n<N; n++) { |
| | for (int h=0; h<H; h++) { |
| | auto kv = torch::zeros({E, M}, queries.options()); |
| | float *kvp = kv.data_ptr<float>(); |
| |
|
| | |
| | for (int l=0; l<L; l++) { |
| | vvt_dot( |
| | &ka[n][h][l][0], |
| | &va[n][h][l][0], |
| | kvp, |
| | E, |
| | M |
| | ); |
| | vmt_dot( |
| | &ga[n][h][l][0], |
| | kvp, |
| | &gqa[n][h][l][0], |
| | E, |
| | M |
| | ); |
| | } |
| |
|
| | |
| | kv.zero_(); |
| | for (int l=L-1; l>=0; l--) { |
| | vvt_dot( |
| | &qa[n][h][l][0], |
| | &ga[n][h][l][0], |
| | kvp, |
| | E, |
| | M |
| | ); |
| | vmt_dot( |
| | &va[n][h][l][0], |
| | kvp, |
| | &gka[n][h][l][0], |
| | E, |
| | M |
| | ); |
| | vm_dot( |
| | &ka[n][h][l][0], |
| | kvp, |
| | &gva[n][h][l][0], |
| | E, |
| | M |
| | ); |
| | } |
| | } |
| | } |
| | } |
| |
|
| |
|
| | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| | m.def( |
| | "causal_dot_product", |
| | &causal_dot_product, |
| | "Compute the weighted sum of values but attending only to previous " |
| | "values." |
| | ); |
| | m.def( |
| | "causal_dot_backward", |
| | &causal_dot_backward, |
| | "Compute the gradient of queries, keys and values given the gradient " |
| | "of causal_dot_product." |
| | ); |
| | } |