// myaxpby_cuda.hpp
// Author: Jakub Homola
#pragma once
#include <cstdlib>
#include <stdexcept>

#include <cuda_runtime.h>
// Host-side launcher for the axpby kernel; only declared here, the definition
// lives in another translation unit.
// NOTE(review): presumably computes an axpby-style update of y from
// alpha * x and beta * y over `count` elements - confirm against the definition.
//
// grid_dim, block_dim, shmem_size, stream: presumably forwarded as the kernel
//     launch configuration - confirm against the definition.
// alpha, beta: scalar coefficients of type T.
// x, y:        pointers to T; the class in this header passes device pointers.
// count:       number of elements of type T (not bytes).
template<typename T>
void launch_my_axpby(dim3 grid_dim, dim3 block_dim, size_t shmem_size, cudaStream_t stream, T alpha, T * x, T beta, T * y, size_t count);
// Host-side launcher for the fill kernel; only declared here, the definition
// lives in another translation unit.
// NOTE(review): the values written into x are not visible from this header -
// confirm against the definition.
//
// grid_dim, block_dim, shmem_size, stream: presumably forwarded as the kernel
//     launch configuration - confirm against the definition.
// x:     pointer to T; the class in this header passes a device pointer.
// count: number of elements of type T (not bytes).
template<typename T>
void launch_my_fill(dim3 grid_dim, dim3 block_dim, size_t shmem_size, cudaStream_t stream, T * x, size_t count);
/**
 * Owner of two device buffers (x and y) of element type T on a single GPU,
 * with helpers to fill them and to run an axpby launch on them.
 *
 * Lifecycle: default-construct, call init() once, call perform() any number
 * of times, then destroy() (also invoked by the destructor). Until init()
 * succeeds in marking the object initialized, alloc(), fill(), perform() and
 * destroy() are no-ops.
 *
 * The object is neither copyable nor movable because it owns raw device
 * allocations.
 *
 * All CUDA calls go through the CHECK macro, which is defined outside this
 * header; its failure behaviour (abort vs. throw) is not visible here.
 */
template<typename T>
class myaxpby
{
public:
    using value_type = T;

private:
    /**
     * Translates an MPI rank into the index of the GPU that rank should use.
     *
     * The table is a function-local static so that indexing it at run time
     * does not depend on an out-of-class definition of a static data member
     * (which would be required before C++17).
     *
     * @param mpi_rank  rank to look up; must be a valid index into the table
     * @return          GPU index to pass to cudaSetDevice
     * @throws std::runtime_error if mpi_rank is negative or beyond the table
     */
    static int gpu_for_rank(int mpi_rank)
    {
        static constexpr int rank_gpu_map[] = {2,3,0,1,6,7,4,5}; // assuming Karolina GPU
        constexpr int map_size = static_cast<int>(sizeof(rank_gpu_map) / sizeof(rank_gpu_map[0]));

        // Reject anything that would index outside the table instead of
        // silently reading past its end.
        if(mpi_rank < 0 || mpi_rank >= map_size)
        {
            throw std::runtime_error("mpi_rank is outside of the rank-to-GPU map");
        }
        return rank_gpu_map[mpi_rank];
    }

public:
    myaxpby() { }

    /** Releases the device buffers if the object is still initialized. */
    ~myaxpby()
    {
        // NOTE(review): if CHECK throws on failure, an error here would escape
        // a destructor - confirm CHECK's behaviour.
        destroy();
    }

    myaxpby(const myaxpby & other) = delete;
    myaxpby(myaxpby && other) = delete;
    myaxpby & operator=(const myaxpby & other) = delete;
    myaxpby & operator=(myaxpby && other) = delete;

    /**
     * Selects the GPU for the given rank, allocates both buffers and fills them.
     *
     * @param mpi_rank    rank used to pick the GPU; must be within the
     *                    rank-to-GPU table
     * @param size_bytes  size of EACH buffer in bytes; rounded down to a whole
     *                    number of T elements
     * @throws std::runtime_error if already initialized or if mpi_rank is
     *         outside the rank-to-GPU table (the object is left untouched in
     *         both cases)
     */
    void init(int mpi_rank, size_t size_bytes)
    {
        if(initialized) throw std::runtime_error("cannot double-initialize");

        // May throw; nothing has been modified yet at this point.
        gpu_idx = gpu_for_rank(mpi_rank);
        count = size_bytes / sizeof(T);

        // Set before alloc()/fill() because both return early otherwise; it
        // also lets destroy() release whatever was allocated so far.
        initialized = true;
        alloc();
        fill();
    }

    /** Allocates the x and y device buffers (count elements each) on the selected GPU. */
    void alloc()
    {
        if(!initialized) return;
        CHECK(cudaSetDevice(gpu_idx));
        CHECK(cudaMalloc(&d_x, count * sizeof(T)));
        CHECK(cudaMalloc(&d_y, count * sizeof(T)));
    }

    /**
     * Launches the fill routine on both buffers using stream 0 and waits for
     * that stream to finish.
     */
    void fill()
    {
        if(!initialized) return;
        CHECK(cudaSetDevice(gpu_idx));
        // Launch configuration: grid 10000, block 256, no dynamic shared memory, stream 0.
        launch_my_fill(10000, 256, 0, 0, d_x, count);
        CHECK(cudaPeekAtLastError());
        launch_my_fill(10000, 256, 0, 0, d_y, count);
        CHECK(cudaPeekAtLastError());
        CHECK(cudaStreamSynchronize(0));
    }

    /**
     * Frees both device buffers and resets the object to its
     * default-constructed state. No-op when not initialized.
     */
    void destroy()
    {
        if(!initialized) return;
        CHECK(cudaSetDevice(gpu_idx));
        CHECK(cudaFree(d_x)); d_x = nullptr;
        CHECK(cudaFree(d_y)); d_y = nullptr;
        gpu_idx = -1;
        count = 0;
        initialized = false;
    }

    /**
     * Runs one axpby launch on the owned buffers with freshly drawn
     * coefficients and waits for stream 0 to finish.
     *
     * alpha and beta are taken from rand(), mapped to [-1, 1] in double
     * precision, narrowed to float and then converted to T.
     */
    void perform()
    {
        if(!initialized) return;
        CHECK(cudaSetDevice(gpu_idx));
        T alpha = static_cast<float>(2.0 * rand() / RAND_MAX - 1.0);
        T beta = static_cast<float>(2.0 * rand() / RAND_MAX - 1.0);
        // Launch configuration: grid 10000, block 256, no dynamic shared memory, stream 0.
        launch_my_axpby(10000, 256, 0, 0, alpha, d_x, beta, d_y, count);
        CHECK(cudaPeekAtLastError());
        CHECK(cudaStreamSynchronize(0));
    }

private:
    int gpu_idx = -1;          // GPU selected in init(); -1 while uninitialized
    size_t count = 0;          // number of T elements in each buffer
    T * d_x = nullptr;         // device buffer x (cudaMalloc'ed in alloc())
    T * d_y = nullptr;         // device buffer y (cudaMalloc'ed in alloc())
    bool initialized = false;  // true between init() and destroy()
};