Commit efe9fabf authored by Lubomir Riha's avatar Lubomir Riha
Browse files

ENH: Forward substitution works in 3D Trans Thomas kernel

parent b85269af
......@@ -740,14 +740,14 @@ __global__ void thomas_kernel3D_X1(int const m, FT const alpha, FT const alpha_2
x[j] = work_buffer_reg;
}
FT x_reg = x[start + (m-1)*n];
x[start + (m-1)*n] = x_reg;
for (int i = m-2; i >= 0; --i) {
int j = start + i*n;
x_reg = x[j] - dev_c_prime[i] * x_reg;
x[j] = x_reg;
}
// FT x_reg = x[start + (m-1)*n];
// x[start + (m-1)*n] = x_reg;
//
// for (int i = m-2; i >= 0; --i) {
// int j = start + i*n;
// x_reg = x[j] - dev_c_prime[i] * x_reg;
// x[j] = x_reg;
// }
}
......@@ -773,14 +773,14 @@ __global__ void thomas_kernel3D_X2(int const m, FT const alpha, FT const alpha_2
x[j] = work_buffer_reg;
}
FT x_reg = x[start + (m-1)*m];
x[start + (m-1)*m] = x_reg;
for (int i = m-2; i >= 0; --i) {
int j = start + i*m;
x_reg = x[j] - dev_c_prime[i] * x_reg;
x[j] = x_reg;
}
// FT x_reg = x[start + (m-1)*m];
// x[start + (m-1)*m] = x_reg;
//
// for (int i = m-2; i >= 0; --i) {
// int j = start + i*m;
// x_reg = x[j] - dev_c_prime[i] * x_reg;
// x[j] = x_reg;
// }
}
......@@ -788,16 +788,21 @@ __global__ void thomas_kernel3D_X2(int const m, FT const alpha, FT const alpha_2
template<class FT>
__global__ void thomas_kernel3D_XT(int const m, FT const alpha, FT const alpha_23, FT const * const __restrict__ dev_c_prime, FT const * const __restrict__ b, FT * const __restrict__ x) {
#define TILE_SIZE 2
//printf("Tile size = %d \n", TILE_SIZE);
__shared__ FT sh_b[4][4];
__shared__ FT sh_x[4][4];
__shared__ FT sh_b[TILE_SIZE][TILE_SIZE];
__shared__ FT sh_x[TILE_SIZE][TILE_SIZE];
int n = m*m;
int TILES = m / TILE_SIZE;
//int n = m*m;
int tid_l = threadIdx.x;
int bid = blockIdx.x;
int bid = blockIdx.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int start = tid;
int base_addr = tid_l + m*m*TILE_SIZE*(bid%TILE_SIZE) + (bid/TILE_SIZE)*m; // + m*m*i
//FT work_buffer_reg = 0.0;
......@@ -805,36 +810,108 @@ __global__ void thomas_kernel3D_XT(int const m, FT const alpha, FT const alpha_2
// read first patch of input data
#pragma unroll
for (int i = 0; i < 4; i++) {
sh_b[tid_l][i] = b[ start + m*m*i ];
for (int i = 0; i < TILE_SIZE; i++) {
int a = base_addr + m*m*i;
sh_b[tid_l][i] = b[a];
//printf("tid = %d - SM a = [%d,%d] - g a = %d ; val = %f \n", tid, tid_l, i, a, b[ a ] );
}
FT work_buffer_reg = sh_b[0][tid_l] * alpha_23 / (FT(2) + alpha);
sh_x[0][tid_l] = work_buffer_reg;
//printf("A tid = %d - work_buffer_reg = %f ; in_val = %f \n", tid, work_buffer_reg, sh_b[0][tid_l]);
for (int i = 1; i < m; ++i) {
//int j = start + i*n;
#pragma unroll
for (int i = 1; i < TILE_SIZE; ++i) {
work_buffer_reg = (sh_b[i][tid_l] * alpha_23 + work_buffer_reg) / (FT(2) + alpha + dev_c_prime[i-1]);
sh_x[i][tid_l] = work_buffer_reg;
//printf("X tid = %d - work_buffer_reg = %f - prim a = %d \n", tid, work_buffer_reg, i-1);
}
FT x_reg = sh_x[m-1][tid_l];
sh_x[m-1][tid_l] = x_reg;
#pragma unroll
for (int i = 0; i < TILE_SIZE; i++) {
int a = base_addr + m*m*i;
x[ a ] = sh_x[tid_l][i];
//printf("tid = %d - SM a = [%d,%d] - g a = %d \n", tid, tid_l, i, a );
}
for (int tile = 1; tile < TILES; tile++) {
#pragma unroll
for (int i = 0; i < TILE_SIZE; i++) {
//int a = tid_l + m*m*TILE_SIZE*bid + m*m*i + tile * TILE_SIZE;
int a = base_addr + m*m*i + tile * TILE_SIZE;
sh_b[tid_l][i] = b[a];
//printf("tid = %d - SM a = [%d,%d] - g a = %d \n", tid, tid_l, i, a );
}
//printf("Y tid = %d - work_buffer_reg = %f \n", tid, work_buffer_reg);
#pragma unroll
for (int i = 0; i < TILE_SIZE; i++) {
work_buffer_reg = (sh_b[i][tid_l] * alpha_23 + work_buffer_reg) / (FT(2) + alpha + dev_c_prime[tile * TILE_SIZE + i - 1]);
sh_x[i][tid_l] = work_buffer_reg;
//printf("Z tid = %d - work_buffer_reg = %f - prim a = %d \n", tid, work_buffer_reg, tile * TILE_SIZE + i - 1);
}
#pragma unroll
for (int i = 0; i < TILE_SIZE; i++) {
int a = base_addr + m*m*i + tile * TILE_SIZE;
x[ a ] = sh_x[tid_l][i];
//printf("tid = %d - SM a = [%d,%d] - g a = %d \n", tid, tid_l, i , a );
}
for (int i = m-2; i >= 0; --i) {
int j = start + i*n;
x_reg = sh_x[i][tid_l] - dev_c_prime[i] * x_reg;
sh_x[i][tid_l] = x_reg;
}
// FT x_reg = sh_x[TILE_SIZE-1][tid_l];
// sh_x[TILE_SIZE-1][tid_l] = x_reg;
//
// for (int i = TILE_SIZE-2; i >= 0; --i) {
// x_reg = sh_x[i][tid_l] - dev_c_prime[m - TILE_SIZE + i] * x_reg;
// sh_x[i][tid_l] = x_reg;
// }
// write patch of input data
#pragma unroll
for (int i = 0; i < 4; i++) {
x[ start + m*m*i ] = sh_x[tid_l][i];
}
// for (int tile = TILES - 1; tile > 0; tile--) {
//
// #pragma unroll
// for (int i = 0; i < TILE_SIZE; i++) {
// sh_b[tid_l][i] = b[ start + m * m * ( tile * TILE_SIZE + i ) ];
// }
//
// #pragma unroll
// for (int i = 0; i < TILE_SIZE; ++i) {
// work_buffer_reg = sh_b[i][tid_l] - dev_c_prime[(1+tile) * TILE_SIZE - i] * work_buffer_reg;
// sh_x[i][tid_l] = work_buffer_reg;
// }
//
// #pragma unroll
// for (int i = 0; i < TILE_SIZE; i++) {
// x[ start + m * m * ( tile * TILE_SIZE + i ) ] = sh_x[tid_l][i];
// }
//
// }
//
// for (int i = m-2; i >= 0; --i) {
// int j = start + i*n;
// x_reg = sh_x[i][tid_l] - dev_c_prime[i] * x_reg;
// sh_x[i][tid_l] = x_reg;
// }
//
//
//
// // write patch of input data
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// x[ start + m*m*i ] = sh_x[tid_l][i];
// }
return;
......@@ -1198,7 +1275,7 @@ public:
}
std::cout << std::endl;
thomas_kernel3D_XT<FT><<<block_count, threads_per_block>>>(m, alpha, alpha_23, c_prime, bbb, xx); //bb);
thomas_kernel3D_XT<FT><<<8, 2>>>(m, alpha, alpha_23, c_prime, bbb, xx); //bb);
cublas_transpose2(cublas_handle, m*m, m, xx, bbb);
cublas_transpose2(cublas_handle, m*m, m, bbb, xx);
......
No preview for this file type
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment