    // accumulator-mfma-cdna-f32.cpp
    //by Branislav Jansik, IT4Innovations, 2024
    #include <stdio.h>
    #include <stdlib.h>
    #include <math.h>
    #include <time.h>
    
    #include "hip/hip_runtime.h"
    #include "hip/hip_runtime_api.h"
    
    // REPEAT10(x): paste the token sequence x ten times (two five-fold copies).
    // Used to unroll inline-asm instruction strings; x must not contain
    // top-level commas.
    #define REPEAT5(x)  x x x x x
    #define REPEAT10(x) REPEAT5(x) REPEAT5(x)

    // Per-thread stride unit: mb_iter reads 2*PIPELEN floats from ar and writes
    // 4*PIPELEN floats to br for every thread.
    #define PIPELEN (8)
    
    
    /*
     * Return a timestamp in seconds, intended for measuring elapsed time as the
     * difference of two calls.
     *
     * CLOCK_MONOTONIC is used because CLOCK_REALTIME can jump (NTP slews, manual
     * clock changes), which would corrupt benchmark intervals. Only differences
     * between calls are meaningful. Falls back to CLOCK_REALTIME if the
     * monotonic clock is unavailable.
     */
    double second(void)
    {
        struct timespec tp;

        if (clock_gettime(CLOCK_MONOTONIC, &tp) != 0)
            clock_gettime(CLOCK_REALTIME, &tp);

        return (double)tp.tv_sec + (double)tp.tv_nsec / 1e9;
    }
    
    
    
    // One group of 8 independent MFMAs, one per accumulator tile a[4k:4k+3],
    // with A operands v0..v7 (loaded inputs) and B operands v8..v15.
    #define MB_MFMA_POS                                             \
        "v_mfma_f32_16x16x4f32 a[0:3],   v0, v8,  a[0:3]\n\t"       \
        "v_mfma_f32_16x16x4f32 a[4:7],   v1, v9,  a[4:7]\n\t"       \
        "v_mfma_f32_16x16x4f32 a[8:11],  v2, v10, a[8:11]\n\t"      \
        "v_mfma_f32_16x16x4f32 a[12:15], v3, v11, a[12:15]\n\t"     \
        "v_mfma_f32_16x16x4f32 a[16:19], v4, v12, a[16:19]\n\t"     \
        "v_mfma_f32_16x16x4f32 a[20:23], v5, v13, a[20:23]\n\t"     \
        "v_mfma_f32_16x16x4f32 a[24:27], v6, v14, a[24:27]\n\t"     \
        "v_mfma_f32_16x16x4f32 a[28:31], v7, v15, a[28:31]\n\t"

    // Same group with the negated A operands v16..v23.
    #define MB_MFMA_NEG                                             \
        "v_mfma_f32_16x16x4f32 a[0:3],   v16, v8,  a[0:3]\n\t"      \
        "v_mfma_f32_16x16x4f32 a[4:7],   v17, v9,  a[4:7]\n\t"      \
        "v_mfma_f32_16x16x4f32 a[8:11],  v18, v10, a[8:11]\n\t"     \
        "v_mfma_f32_16x16x4f32 a[12:15], v19, v11, a[12:15]\n\t"    \
        "v_mfma_f32_16x16x4f32 a[16:19], v20, v12, a[16:19]\n\t"    \
        "v_mfma_f32_16x16x4f32 a[20:23], v21, v13, a[20:23]\n\t"    \
        "v_mfma_f32_16x16x4f32 a[24:27], v22, v14, a[24:27]\n\t"    \
        "v_mfma_f32_16x16x4f32 a[28:31], v23, v15, a[28:31]\n\t"

    /*
     * mb_iter: single-precision MFMA throughput micro-benchmark.
     *
     * Per thread (tid = global 1D thread index):
     *  - loads 16 floats from ar[16*tid .. 16*tid+15] into v0..v15 and writes
     *    negated copies of v0..v7 into v16..v23;
     *  - zeroes AGPRs a0..a31 and runs 8 independent accumulation chains of
     *    v_mfma_f32_16x16x4f32 (tiles a[0:3] .. a[28:31]), with A from v0..v7 or
     *    v16..v23 and B from v8..v15;
     *  - copies a0..a31 to v16..v47 and stores them to br[32*tid .. 32*tid+31].
     *
     * The loop body issues 10 * (80 + 160 + 80) = 3200 MFMA instructions per
     * wavefront and runs 40000 iterations. Within each of the 10 outer
     * repetitions, every chain receives 20 MFMAs with its original A operand and
     * 20 with the negated one, so the results are expected to cancel to zero
     * (up to rounding).
     *
     * Preconditions / notes:
     *  - ar must hold 16 floats and br 32 floats per launched thread; there is
     *    no bounds check.
     *  - Requires an AMD target with MFMA and AccVGPR instructions (e.g. CDNA
     *    gfx908/gfx90a).
     *  - The iteration count 40000 is hard-coded here and in the host FLIPS
     *    formula; keep them in sync.
     */
    __global__ void mb_iter(float *br, float *ar)
    {
        // size_t keeps the per-thread offsets from overflowing int on huge grids.
        const size_t tid = (size_t)threadIdx.x + (size_t)blockIdx.x * blockDim.x;

        float *in  = ar + 2 * PIPELEN * tid;   // 16 input floats for this thread
        float *out = br + 4 * PIPELEN * tid;   // 32 output floats for this thread

        asm volatile(
            // load matrix element inputs
            "global_load_dword v0,  %0, off\n\t"
            "global_load_dword v1,  %0, off offset:4\n\t"
            "global_load_dword v2,  %0, off offset:8\n\t"
            "global_load_dword v3,  %0, off offset:12\n\t"
            "global_load_dword v4,  %0, off offset:16\n\t"
            "global_load_dword v5,  %0, off offset:20\n\t"
            "global_load_dword v6,  %0, off offset:24\n\t"
            "global_load_dword v7,  %0, off offset:28\n\t"
            "global_load_dword v8,  %0, off offset:32\n\t"
            "global_load_dword v9,  %0, off offset:36\n\t"
            "global_load_dword v10, %0, off offset:40\n\t"
            "global_load_dword v11, %0, off offset:44\n\t"
            "global_load_dword v12, %0, off offset:48\n\t"
            "global_load_dword v13, %0, off offset:52\n\t"
            "global_load_dword v14, %0, off offset:56\n\t"
            "global_load_dword v15, %0, off offset:60\n\t"

            "s_waitcnt lgkmcnt(0), vmcnt(0)\n\t"

            // negate the A-operand inputs
            "v_mul_f32 v16, v0, -1.0\n\t"
            "v_mul_f32 v17, v1, -1.0\n\t"
            "v_mul_f32 v18, v2, -1.0\n\t"
            "v_mul_f32 v19, v3, -1.0\n\t"
            "v_mul_f32 v20, v4, -1.0\n\t"
            "v_mul_f32 v21, v5, -1.0\n\t"
            "v_mul_f32 v22, v6, -1.0\n\t"
            "v_mul_f32 v23, v7, -1.0\n\t"

            // initialize accumulators to zero
            "v_accvgpr_write_b32 a0,  0\n\t"
            "v_accvgpr_write_b32 a1,  0\n\t"
            "v_accvgpr_write_b32 a2,  0\n\t"
            "v_accvgpr_write_b32 a3,  0\n\t"
            "v_accvgpr_write_b32 a4,  0\n\t"
            "v_accvgpr_write_b32 a5,  0\n\t"
            "v_accvgpr_write_b32 a6,  0\n\t"
            "v_accvgpr_write_b32 a7,  0\n\t"
            "v_accvgpr_write_b32 a8,  0\n\t"
            "v_accvgpr_write_b32 a9,  0\n\t"
            "v_accvgpr_write_b32 a10, 0\n\t"
            "v_accvgpr_write_b32 a11, 0\n\t"
            "v_accvgpr_write_b32 a12, 0\n\t"
            "v_accvgpr_write_b32 a13, 0\n\t"
            "v_accvgpr_write_b32 a14, 0\n\t"
            "v_accvgpr_write_b32 a15, 0\n\t"
            "v_accvgpr_write_b32 a16, 0\n\t"
            "v_accvgpr_write_b32 a17, 0\n\t"
            "v_accvgpr_write_b32 a18, 0\n\t"
            "v_accvgpr_write_b32 a19, 0\n\t"
            "v_accvgpr_write_b32 a20, 0\n\t"
            "v_accvgpr_write_b32 a21, 0\n\t"
            "v_accvgpr_write_b32 a22, 0\n\t"
            "v_accvgpr_write_b32 a23, 0\n\t"
            "v_accvgpr_write_b32 a24, 0\n\t"
            "v_accvgpr_write_b32 a25, 0\n\t"
            "v_accvgpr_write_b32 a26, 0\n\t"
            "v_accvgpr_write_b32 a27, 0\n\t"
            "v_accvgpr_write_b32 a28, 0\n\t"
            "v_accvgpr_write_b32 a29, 0\n\t"
            "v_accvgpr_write_b32 a30, 0\n\t"
            "v_accvgpr_write_b32 a31, 0\n\t"

            // zero out loop counter
            "s_mov_b32 s8, 0\n\t"

            // MFMA loop: +10, -20, +10 groups per outer repetition
            "loop:\n\t"
            REPEAT10(
                REPEAT10(MB_MFMA_POS)
                REPEAT10(MB_MFMA_NEG MB_MFMA_NEG)
                REPEAT10(MB_MFMA_POS)
            )

            "s_add_i32 s8, s8, 1\n\t"
            "s_cmp_lt_i32 s8, 40000\n\t"
            "s_cbranch_scc1 loop\n\t"

            // padding before reading the accumulators (labelled mandatory by the
            // author; presumably covers the MFMA -> v_accvgpr_read hazard)
            "s_nop 8\n\t"

            // offload accumulators to VGPRs
            "v_accvgpr_read_b32 v16, a0\n\t"
            "v_accvgpr_read_b32 v17, a1\n\t"
            "v_accvgpr_read_b32 v18, a2\n\t"
            "v_accvgpr_read_b32 v19, a3\n\t"
            "v_accvgpr_read_b32 v20, a4\n\t"
            "v_accvgpr_read_b32 v21, a5\n\t"
            "v_accvgpr_read_b32 v22, a6\n\t"
            "v_accvgpr_read_b32 v23, a7\n\t"
            "v_accvgpr_read_b32 v24, a8\n\t"
            "v_accvgpr_read_b32 v25, a9\n\t"
            "v_accvgpr_read_b32 v26, a10\n\t"
            "v_accvgpr_read_b32 v27, a11\n\t"
            "v_accvgpr_read_b32 v28, a12\n\t"
            "v_accvgpr_read_b32 v29, a13\n\t"
            "v_accvgpr_read_b32 v30, a14\n\t"
            "v_accvgpr_read_b32 v31, a15\n\t"
            "v_accvgpr_read_b32 v32, a16\n\t"
            "v_accvgpr_read_b32 v33, a17\n\t"
            "v_accvgpr_read_b32 v34, a18\n\t"
            "v_accvgpr_read_b32 v35, a19\n\t"
            "v_accvgpr_read_b32 v36, a20\n\t"
            "v_accvgpr_read_b32 v37, a21\n\t"
            "v_accvgpr_read_b32 v38, a22\n\t"
            "v_accvgpr_read_b32 v39, a23\n\t"
            "v_accvgpr_read_b32 v40, a24\n\t"
            "v_accvgpr_read_b32 v41, a25\n\t"
            "v_accvgpr_read_b32 v42, a26\n\t"
            "v_accvgpr_read_b32 v43, a27\n\t"
            "v_accvgpr_read_b32 v44, a28\n\t"
            "v_accvgpr_read_b32 v45, a29\n\t"
            "v_accvgpr_read_b32 v46, a30\n\t"
            "v_accvgpr_read_b32 v47, a31\n\t"

            // store the 32 results
            "global_store_dword %1, v16, off\n\t"
            "global_store_dword %1, v17, off offset:4\n\t"
            "global_store_dword %1, v18, off offset:8\n\t"
            "global_store_dword %1, v19, off offset:12\n\t"
            "global_store_dword %1, v20, off offset:16\n\t"
            "global_store_dword %1, v21, off offset:20\n\t"
            "global_store_dword %1, v22, off offset:24\n\t"
            "global_store_dword %1, v23, off offset:28\n\t"
            "global_store_dword %1, v24, off offset:32\n\t"
            "global_store_dword %1, v25, off offset:36\n\t"
            "global_store_dword %1, v26, off offset:40\n\t"
            "global_store_dword %1, v27, off offset:44\n\t"
            "global_store_dword %1, v28, off offset:48\n\t"
            "global_store_dword %1, v29, off offset:52\n\t"
            "global_store_dword %1, v30, off offset:56\n\t"
            "global_store_dword %1, v31, off offset:60\n\t"

            "global_store_dword %1, v32, off offset:64\n\t"
            "global_store_dword %1, v33, off offset:68\n\t"
            "global_store_dword %1, v34, off offset:72\n\t"
            "global_store_dword %1, v35, off offset:76\n\t"
            "global_store_dword %1, v36, off offset:80\n\t"
            "global_store_dword %1, v37, off offset:84\n\t"
            "global_store_dword %1, v38, off offset:88\n\t"
            "global_store_dword %1, v39, off offset:92\n\t"
            "global_store_dword %1, v40, off offset:96\n\t"
            "global_store_dword %1, v41, off offset:100\n\t"
            "global_store_dword %1, v42, off offset:104\n\t"
            "global_store_dword %1, v43, off offset:108\n\t"
            "global_store_dword %1, v44, off offset:112\n\t"
            "global_store_dword %1, v45, off offset:116\n\t"
            "global_store_dword %1, v46, off offset:120\n\t"
            "global_store_dword %1, v47, off offset:124\n\t"

            "s_waitcnt lgkmcnt(0),vmcnt(0)\n\t"

            // outputs, inputs and clobbers. Every register named in the asm text
            // must be listed so the compiler neither keeps live values there nor
            // under-reports register usage; this includes the AGPRs a0..a31 and
            // SCC (written by s_add_i32 / s_cmp_lt_i32).
            : /* no outputs */
            : "v"(in), "v"(out)
            : "s8", "scc",
              "v0","v1","v2","v3","v4","v5","v6","v7",
              "v8","v9","v10","v11","v12","v13","v14","v15",
              "v16","v17","v18","v19","v20","v21","v22","v23",
              "v24","v25","v26","v27","v28","v29","v30","v31",
              "v32","v33","v34","v35","v36","v37","v38","v39",
              "v40","v41","v42","v43","v44","v45","v46","v47",
              "a0","a1","a2","a3","a4","a5","a6","a7",
              "a8","a9","a10","a11","a12","a13","a14","a15",
              "a16","a17","a18","a19","a20","a21","a22","a23",
              "a24","a25","a26","a27","a28","a29","a30","a31",
              "memory");
    }

    #undef MB_MFMA_POS
    #undef MB_MFMA_NEG
    
    
    int main(int argc, char** argv){
       
        float *ar, *br; 	
    
        double t0, t;
        float ainc;
        int size, i, j , off , GPU_N, niter=1;
        int NBLOCKS=480, NTHREADS=256;
    
        //usage    
        if (argc==1) printf("Usage: %s [niter] [NBLOCKS NTHREADS]\n",argv[0]);
    
        //inits
        if (argc>1) niter = atoi(argv[1]);
        if (argc>2) { NBLOCKS=atoi(argv[2]); NTHREADS=atoi(argv[3]); }
    
        //get HIP capable device count
        if(hipGetDeviceCount(&GPU_N))
        { printf("HIP error: hipGetDeviceCount\n"); return(1); }
        
        //set size
        size = NBLOCKS*NTHREADS;
    
        printf("HIP capable device count: %i\n", GPU_N);
        printf("FLIPS: Single Precision MFMA 16x16x4 Instructions Per Second\nRun %d times, %d Blocks, %d threads\n\n",
    		            niter, NBLOCKS, NTHREADS);
    
        hipHostMalloc((void **)&ar, 2*GPU_N*PIPELEN*size*sizeof(float));    
        hipHostMalloc((void **)&br, 4*GPU_N*PIPELEN*size*sizeof(float));    
      
        //initialize ar array
        ainc = 2.25/(float) (2*GPU_N*PIPELEN*size);
        for (i=0;i<2*GPU_N*PIPELEN*size; i++)
            ar[i]  = -2.0 + ainc*i;
    
    
    
        //iterations loop
        for (i=0; i<niter; i++) {
    
    
            t0 = second();
    
            for (j=0; j < GPU_N; j++) {
             hipSetDevice(j);
             hipLaunchKernelGGL( mb_iter, dim3(NBLOCKS), dim3(NTHREADS), 0, 0, &br[4*j*PIPELEN*size],&ar[2*j*PIPELEN*size]);
            }
    	
            for (j=0; j< GPU_N; j++) {
             hipSetDevice(j);
             hipDeviceSynchronize();
            }
    
            t = second();
        	
            //FLIPS = GPU_N*4*PIPELEN*size*40000*100/1e9
    	printf("%d: Summary FLIPS: %6.4fG \n", i+1, (4*GPU_N*PIPELEN*size*6.25e-5/(t-t0)));
    
        }
    
        printf("\nOutputs (Should be all 0.0):\n");
        for(int i=0; i<4*GPU_N*PIPELEN*size; i+=4*GPU_N*PIPELEN*size/20)
        printf("br: %d: %f %f %f %f\n", i, br[i+0], br[i+1], br[i+2], br[i+3]);
    
        //dump flips and frequency data
        //no thread level measurement implemented yet
        /*
        if (argc>3) {
         FILE *fp0 = fopen(argv[2],"w");
         FILE *fp1 = fopen(argv[3],"w");
         fwrite((void *)aflips,  sizeof(double), GPU_N*niter*size, fp0);
         fwrite((void *)afreqs,  sizeof(double), GPU_N*niter*size, fp1);
         fclose(fp0); fclose(fp1);
         printf("Full data dump in %s and %s\n", argv[2], argv[3]);
        }
        */
    
        //free memory arrays
        hipHostFree(br);
        hipHostFree(ar);
    
        return 0;
    }