Spaces:
Running
on
T4
Running
on
T4
// Short aliases for the element types accepted by the cuda_forward_* entry points.
using fp16 = at::Half;     // ATen half-precision type
using bf16 = at::BFloat16; // ATen bfloat16 type
using fp32 = float;        // plain single precision
// Forward pass of a per-head linear recurrence with an _N_ x _N_ float state
// per (batch, head). NOTE(review): this looks like the RWKV-7 "wkv" forward
// kernel — confirm against the model code.
//
// Launch layout (as used by the host wrappers): gridDim.x == B * H,
// blockDim.x == _N_.
//   blockIdx.x  -> e = batch index, h = head index
//   threadIdx.x -> i = row of the head's state matrix owned by this thread
//
// Data layout implied by the index arithmetic:
//   r, w, k, v, a, b, y : element (e, _t, h*_N_ + i) of a row-major (B, T, C)
//                         array, with C == H * _N_.
//   _state              : float matrices laid out (B, H, _N_, _N_). Row i is
//                         loaded into registers before the time loop and
//                         written back after it, so the state is updated in place.
//
// Per time step, with s = row i of the state and w_j = exp(-exp(w_raw_j)):
//   sa  = sum_j a_j * s_j                  (uses the state before this step)
//   s_j = s_j * w_j + k_j * v_i + sa * b_j
//   y_i = sum_j s_j * r_j                  (uses the updated state)
//
// Static shared memory: 5 * _N_ floats. Uses the fast-math intrinsic __expf.
// B is not read by the kernel; the batch index is derived from blockIdx.x.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
                               F *__restrict__ const _y)
{
    const int e = blockIdx.x / H; // batch index
    const int h = blockIdx.x % H; // head index
    const int i = threadIdx.x;    // state row / channel within the head

    // Select row i of the (e, h) state matrix. The batch term e*H is required
    // so that blocks of different batches do not share (and race on) the same
    // matrix; for B == 1 (e == 0) this is the same offset as h*_N_*_N_ + i*_N_.
    // 64-bit arithmetic guards against int overflow for large B*H.
    _state += ((size_t)e * H + h) * (_N_ * _N_) + (size_t)i * _N_;

    float state[_N_];
    for (int j = 0; j < _N_; j++)
        state[j] = _state[j];

    __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];

    // Offset of element (e, 0, h*_N_ + i); each time step advances by C.
    // Computed in 64 bits so e*T*C and _t*C cannot overflow int.
    const size_t base = (size_t)e * T * C + (size_t)h * _N_ + i;

    for (int _t = 0; _t < T; _t++)
    {
        const size_t t = base + (size_t)_t * C;

        // Barrier: no thread may overwrite the shared vectors while another
        // thread is still reading the previous step's values.
        __syncthreads();
        r[i] = float(_r[t]);
        w[i] = __expf(-__expf(float(_w[t]))); // decay factor exp(-exp(w))
        k[i] = float(_k[t]);
        a[i] = float(_a[t]);
        b[i] = float(_b[t]);
        // Barrier: every thread needs all _N_ entries of each vector below.
        __syncthreads();

        // Projection of the current state row onto a, taken before the update.
        float sa = 0;
        for (int j = 0; j < _N_; j++)
        {
            sa += a[j] * state[j];
        }

        const float vv = float(_v[t]);
        float y = 0;
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            s = s * w[j] + k[j] * vv + sa * b[j];
            y += s * r[j];
        }
        _y[t] = F(y);
    }

    // Persist the updated state row for the next call.
    for (int j = 0; j < _N_; j++)
        _state[j] = state[j];
}
// Host-side launcher shared by the cuda_forward_* entry points below.
// Launches kernel_forward with one block per (batch, head) pair and _N_
// threads per block, on the default stream, without synchronizing.
// Preconditions (checked with assert, i.e. only when NDEBUG is not defined):
//   * H * _N_ == C : each head spans exactly _N_ channels;
//   * B == 1       : only single-batch calls are accepted here.
//     NOTE(review): lifting this requires `state` to hold B*H*_N_*_N_ floats
//     and kernel_forward to index the state per batch — confirm before relaxing.
// Pointers are presumably device pointers owned by the caller — TODO confirm.
template <typename F>
static void launch_forward(int B, int T, int C, int H, float *state, F *r, F *w, F *k, F *v, F *a, F *b, F *y)
{
    assert(H*_N_ == C);
    assert(B == 1); // only for B=1
    // An empty grid (H == 0, i.e. C == 0) is an invalid launch configuration
    // and would leave a pending CUDA error; there is no work, so skip it.
    if (B * H <= 0)
        return;
    kernel_forward<F><<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}

// bf16 entry point: activations r, w, k, v, a, b and output y are bf16; state is fp32.
void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
    launch_forward(B, T, C, H, state, r, w, k, v, a, b, y);
}

// fp16 entry point: activations r, w, k, v, a, b and output y are fp16; state is fp32.
void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
{
    launch_forward(B, T, C, H, state, r, w, k, v, a, b, y);
}

// fp32 entry point: activations r, w, k, v, a, b, output y and state are all fp32.
void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
{
    launch_forward(B, T, C, H, state, r, w, k, v, a, b, y);
}