libflame revision_anchor
Functions
FLA_Gemm_external_gpu.c File Reference

(r)

Functions

FLA_Error FLA_Gemm_external_gpu (FLA_Trans transa, FLA_Trans transb, FLA_Obj alpha, FLA_Obj A, void *A_gpu, FLA_Obj B, void *B_gpu, FLA_Obj beta, FLA_Obj C, void *C_gpu)
 

Function Documentation

◆ FLA_Gemm_external_gpu()

FLA_Error FLA_Gemm_external_gpu ( FLA_Trans  transa,
FLA_Trans  transb,
FLA_Obj  alpha,
FLA_Obj  A,
void *  A_gpu,
FLA_Obj  B,
void *  B_gpu,
FLA_Obj  beta,
FLA_Obj  C,
void *  C_gpu 
)
18{
19 FLA_Datatype datatype;
20 int k_AB;
21 int m_A, n_A;
22 int m_C, n_C;
23 int ldim_A;
24 int ldim_B;
25 int ldim_C;
26 char blas_transa;
27 char blas_transb;
28
31
32 if ( FLA_Obj_has_zero_dim( C ) ) return FLA_SUCCESS;
33
35 {
37 return FLA_SUCCESS;
38 }
39
40 datatype = FLA_Obj_datatype( A );
41
42 m_A = FLA_Obj_length( A );
43 n_A = FLA_Obj_width( A );
45
47
48 m_C = FLA_Obj_length( C );
49 n_C = FLA_Obj_width( C );
51
53 k_AB = n_A;
54 else
55 k_AB = m_A;
56
59
60
61 switch( datatype ){
62
63 case FLA_FLOAT:
64 {
65 float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
66 float *buff_beta = ( float * ) FLA_FLOAT_PTR( beta );
67
70 m_C,
71 n_C,
72 k_AB,
74 ( float * ) A_gpu, ldim_A,
75 ( float * ) B_gpu, ldim_B,
76 *buff_beta,
77 ( float * ) C_gpu, ldim_C );
78
79 break;
80 }
81
82 case FLA_DOUBLE:
83 {
84 double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
85 double *buff_beta = ( double * ) FLA_DOUBLE_PTR( beta );
86
89 m_C,
90 n_C,
91 k_AB,
93 ( double * ) A_gpu, ldim_A,
94 ( double * ) B_gpu, ldim_B,
95 *buff_beta,
96 ( double * ) C_gpu, ldim_C );
97
98 break;
99 }
100
101 case FLA_COMPLEX:
102 {
105
108 m_C,
109 n_C,
110 k_AB,
111 *buff_alpha,
112 ( cuComplex * ) A_gpu, ldim_A,
113 ( cuComplex * ) B_gpu, ldim_B,
114 *buff_beta,
115 ( cuComplex * ) C_gpu, ldim_C );
116
117 break;
118 }
119
121 {
124
127 m_C,
128 n_C,
129 k_AB,
130 *buff_alpha,
133 *buff_beta,
135
136 break;
137 }
138
139 }
140
141 return FLA_SUCCESS;
142}
FLA_Error FLA_Gemm_check(FLA_Trans transa, FLA_Trans transb, FLA_Obj alpha, FLA_Obj A, FLA_Obj B, FLA_Obj beta, FLA_Obj C)
Definition FLA_Gemm_check.c:13
FLA_Error FLA_Scal_external_gpu(FLA_Obj alpha, FLA_Obj A, void *A_gpu)
Definition FLA_Scal_external_gpu.c:17
dim_t FLA_Obj_width(FLA_Obj obj)
Definition FLA_Query.c:123
FLA_Bool FLA_Obj_has_zero_dim(FLA_Obj A)
Definition FLA_Query.c:400
dim_t FLA_Obj_length(FLA_Obj obj)
Definition FLA_Query.c:116
unsigned int FLA_Check_error_level(void)
Definition FLA_Check.c:18
void FLA_Param_map_flame_to_netlib_trans(FLA_Trans trans, void *blas_trans)
Definition FLA_Param.c:15
FLA_Datatype FLA_Obj_datatype(FLA_Obj obj)
Definition FLA_Query.c:13
int FLA_Datatype
Definition FLA_type_defs.h:49
int i
Definition bl1_axmyv2.c:145

References FLA_Check_error_level(), FLA_Gemm_check(), FLA_Obj_datatype(), FLA_Obj_has_zero_dim(), FLA_Obj_length(), FLA_Obj_width(), FLA_Param_map_flame_to_netlib_trans(), FLA_Scal_external_gpu(), and i.

Referenced by FLASH_Queue_exec_task_gpu().