You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

605 lines
15 KiB

# cython: infer_types=True
# cython: boundscheck=False
# Copyright ExplosionAI GmbH, released under BSD.
import atexit
# Declarations for the parts of the BLIS C API used by this module.  The
# whole block is ``nogil``, so every routine below may be called without
# holding the GIL.  ``dim_t`` and ``inc_t`` are not declared here; they must
# already be in scope (e.g. from the accompanying .pxd).
cdef extern from "blis.h" nogil:
    # ---- Types declared without fields -----------------------------------
    enum blis_err_t "err_t":
        pass

    cdef struct blis_cntx_t "cntx_t":
        pass

    cdef struct blis_rntm_t "rntm_s":
        pass

    # ---- Operation flags -------------------------------------------------
    ctypedef enum blis_trans_t "trans_t":
        BLIS_NO_TRANSPOSE
        BLIS_TRANSPOSE
        BLIS_CONJ_NO_TRANSPOSE
        BLIS_CONJ_TRANSPOSE

    ctypedef enum blis_conj_t "conj_t":
        BLIS_NO_CONJUGATE
        BLIS_CONJUGATE

    ctypedef enum blis_side_t "side_t":
        BLIS_LEFT
        BLIS_RIGHT

    ctypedef enum blis_uplo_t "uplo_t":
        BLIS_LOWER
        BLIS_UPPER
        BLIS_DENSE

    ctypedef enum blis_diag_t "diag_t":
        BLIS_NONUNIT_DIAG
        BLIS_UNIT_DIAG

    # ---- Library information and lifecycle -------------------------------
    char* bli_info_get_int_type_size_str()
    blis_err_t bli_init()
    blis_err_t bli_finalize()
    blis_err_t bli_rntm_init(blis_rntm_t* rntm)

    # ---- BLAS level 3 routines -------------------------------------------
    void bli_dgemm_ex(
        blis_trans_t transa, blis_trans_t transb,
        dim_t m, dim_t n, dim_t k,
        double* alpha,
        double* a, inc_t rsa, inc_t csa,
        double* b, inc_t rsb, inc_t csb,
        double* beta,
        double* c, inc_t rsc, inc_t csc,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_sgemm_ex(
        blis_trans_t transa, blis_trans_t transb,
        dim_t m, dim_t n, dim_t k,
        float* alpha,
        float* a, inc_t rsa, inc_t csa,
        float* b, inc_t rsb, inc_t csb,
        float* beta,
        float* c, inc_t rsc, inc_t csc,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    # ---- Matrix/vector routines (ger, gemv) ------------------------------
    void bli_dger_ex(
        blis_conj_t conjx, blis_conj_t conjy,
        dim_t m, dim_t n,
        double* alpha,
        double* x, inc_t incx,
        double* y, inc_t incy,
        double* a, inc_t rsa, inc_t csa,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_sger_ex(
        blis_conj_t conjx, blis_conj_t conjy,
        dim_t m, dim_t n,
        float* alpha,
        float* x, inc_t incx,
        float* y, inc_t incy,
        float* a, inc_t rsa, inc_t csa,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_dgemv_ex(
        blis_trans_t transa, blis_conj_t conjx,
        dim_t m, dim_t n,
        double* alpha,
        double* a, inc_t rsa, inc_t csa,
        double* x, inc_t incx,
        double* beta,
        double* y, inc_t incy,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_sgemv_ex(
        blis_trans_t transa, blis_conj_t conjx,
        dim_t m, dim_t n,
        float* alpha,
        float* a, inc_t rsa, inc_t csa,
        float* x, inc_t incx,
        float* beta,
        float* y, inc_t incy,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    # ---- Vector routines (axpyv, scalv, dotv) ----------------------------
    void bli_daxpyv_ex(
        blis_conj_t conjx,
        dim_t m,
        double* alpha,
        double* x, inc_t incx,
        double* y, inc_t incy,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_saxpyv_ex(
        blis_conj_t conjx,
        dim_t m,
        float* alpha,
        float* x, inc_t incx,
        float* y, inc_t incy,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_dscalv_ex(
        blis_conj_t conjalpha,
        dim_t m,
        double* alpha,
        double* x, inc_t incx,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_sscalv_ex(
        blis_conj_t conjalpha,
        dim_t m,
        float* alpha,
        float* x, inc_t incx,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_ddotv_ex(
        blis_conj_t conjx, blis_conj_t conjy,
        dim_t m,
        double* x, inc_t incx,
        double* y, inc_t incy,
        double* rho,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_sdotv_ex(
        blis_conj_t conjx, blis_conj_t conjy,
        dim_t m,
        float* x, inc_t incx,
        float* y, inc_t incy,
        float* rho,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    # ---- Vector norms (norm1v, normfv, normiv) ---------------------------
    void bli_snorm1v_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_dnorm1v_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_snormfv_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_dnormfv_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_snormiv_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_dnormiv_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    # ---- Utility routines (randv, sumsqv) --------------------------------
    void bli_srandv_ex(
        dim_t m,
        float* x, inc_t incx,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_drandv_ex(
        dim_t m,
        double* x, inc_t incx,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    # The enclosing block is already ``nogil``, so no per-function qualifier
    # is needed on the two declarations below.
    void bli_ssumsqv_ex(
        dim_t m,
        float* x, inc_t incx,
        float* scale,
        float* sumsq,
        blis_cntx_t* cntx, blis_rntm_t* rntm)

    void bli_dsumsqv_ex(
        dim_t m,
        double* x, inc_t incx,
        double* scale,
        double* sumsq,
        blis_cntx_t* cntx, blis_rntm_t* rntm)
# Runtime object shared by the whole module: every ``bli_*_ex`` call below
# receives ``&rntm``.  It is filled in by ``init()`` via ``bli_rntm_init``.
cdef blis_rntm_t rntm

# Initialise BLIS as soon as the module is imported.
bli_init()
def init():
    """Initialise BLIS and the module-level runtime object.

    Calls ``bli_init()`` and then ``bli_rntm_init()`` on the shared ``rntm``.
    Afterwards it asserts that each enum constant declared from ``blis.h``
    equals the identically named module-level constant (``NO_TRANSPOSE``,
    ``LEFT``, ...) once that constant is cast to the matching BLIS enum type.

    Raises:
        AssertionError: if any pair of constants disagrees.
    """
    bli_init()
    bli_rntm_init(&rntm)

    # Transposition flags.
    assert BLIS_NO_TRANSPOSE == <blis_trans_t>NO_TRANSPOSE
    assert BLIS_TRANSPOSE == <blis_trans_t>TRANSPOSE
    assert BLIS_CONJ_NO_TRANSPOSE == <blis_trans_t>CONJ_NO_TRANSPOSE
    assert BLIS_CONJ_TRANSPOSE == <blis_trans_t>CONJ_TRANSPOSE

    # Conjugation flags.
    assert BLIS_NO_CONJUGATE == <blis_conj_t>NO_CONJUGATE
    assert BLIS_CONJUGATE == <blis_conj_t>CONJUGATE

    # Side flags.
    assert BLIS_LEFT == <blis_side_t>LEFT
    assert BLIS_RIGHT == <blis_side_t>RIGHT

    # Upper/lower/dense storage flags.
    assert BLIS_LOWER == <blis_uplo_t>LOWER
    assert BLIS_UPPER == <blis_uplo_t>UPPER
    assert BLIS_DENSE == <blis_uplo_t>DENSE

    # Diagonal flags.
    assert BLIS_NONUNIT_DIAG == <blis_diag_t>NONUNIT_DIAG
    assert BLIS_UNIT_DIAG == <blis_diag_t>UNIT_DIAG
def get_int_type_size():
    """Return the string reported by ``bli_info_get_int_type_size_str()``.

    BLIS returns a C string (presumably the integer width such as "32" or
    "64" -- confirm against the BLIS documentation).  The whole string is
    decoded and returned; indexing its first ``char`` and formatting it with
    ``%d`` would instead yield the character's ordinal (e.g. ``'6'`` -> "54").

    Returns:
        str: the decoded value of the C string.
    """
    cdef char* int_size = bli_info_get_int_type_size_str()
    return int_size.decode("ascii")
# BLAS level 3 routines
cdef void gemm(
    trans_t trans_a,
    trans_t trans_b,
    dim_t m,
    dim_t n,
    dim_t k,
    double alpha,
    reals_ft a, inc_t rsa, inc_t csa,
    reals_ft b, inc_t rsb, inc_t csb,
    double beta,
    reals_ft c, inc_t rsc, inc_t csc
) nogil:
    """Dispatch a matrix-matrix multiply to ``bli_sgemm_ex``/``bli_dgemm_ex``.

    The routine is chosen from the ``reals_ft`` specialisation: the
    ``floats_t``/``float1d_t`` variants call the single-precision routine
    (with ``alpha``/``beta`` narrowed to ``float``), the
    ``doubles_t``/``double1d_t`` variants call the double-precision one.
    For the ``*1d_t`` specialisations the address of element 0 is passed.
    The context argument is always ``NULL`` and the runtime argument is the
    module-level ``rntm``.
    """
    cdef blis_trans_t op_a = <blis_trans_t>trans_a
    cdef blis_trans_t op_b = <blis_trans_t>trans_b
    # Scalars are passed to BLIS by pointer, so keep addressable copies in
    # both precisions.
    cdef float alpha32 = alpha
    cdef float beta32 = beta
    cdef double alpha64 = alpha
    cdef double beta64 = beta

    if reals_ft is floats_t:
        bli_sgemm_ex(
            op_a, op_b, m, n, k,
            &alpha32,
            a, rsa, csa,
            b, rsb, csb,
            &beta32,
            c, rsc, csc,
            NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dgemm_ex(
            op_a, op_b, m, n, k,
            &alpha64,
            a, rsa, csa,
            b, rsb, csb,
            &beta64,
            c, rsc, csc,
            NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_sgemm_ex(
            op_a, op_b, m, n, k,
            &alpha32,
            &a[0], rsa, csa,
            &b[0], rsb, csb,
            &beta32,
            &c[0], rsc, csc,
            NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dgemm_ex(
            op_a, op_b, m, n, k,
            &alpha64,
            &a[0], rsa, csa,
            &b[0], rsb, csb,
            &beta64,
            &c[0], rsc, csc,
            NULL, &rntm)
    # Any other specialisation is a silent no-op.
cdef void ger(
    conj_t conjx,
    conj_t conjy,
    dim_t m,
    dim_t n,
    double alpha,
    reals_ft x, inc_t incx,
    reals_ft y, inc_t incy,
    reals_ft a, inc_t rsa, inc_t csa
) nogil:
    """Dispatch to ``bli_sger_ex`` or ``bli_dger_ex`` for the given data type.

    Single-precision specialisations (``floats_t``, ``float1d_t``) receive
    ``alpha`` narrowed to ``float``; double-precision ones receive it
    unchanged.  ``*1d_t`` arguments are passed as the address of element 0.
    The context is ``NULL`` and the runtime is the module-level ``rntm``.
    """
    cdef blis_conj_t cx = <blis_conj_t>conjx
    cdef blis_conj_t cy = <blis_conj_t>conjy
    # BLIS takes the scalar by pointer, so keep one copy per precision.
    cdef float alpha32 = alpha
    cdef double alpha64 = alpha

    if reals_ft is floats_t:
        bli_sger_ex(
            cx, cy, m, n,
            &alpha32,
            x, incx,
            y, incy,
            a, rsa, csa,
            NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dger_ex(
            cx, cy, m, n,
            &alpha64,
            x, incx,
            y, incy,
            a, rsa, csa,
            NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_sger_ex(
            cx, cy, m, n,
            &alpha32,
            &x[0], incx,
            &y[0], incy,
            &a[0], rsa, csa,
            NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dger_ex(
            cx, cy, m, n,
            &alpha64,
            &x[0], incx,
            &y[0], incy,
            &a[0], rsa, csa,
            NULL, &rntm)
    # Any other specialisation is a silent no-op.
cdef void gemv(
    trans_t transa,
    conj_t conjx,
    dim_t m,
    dim_t n,
    real_ft alpha,
    reals_ft a, inc_t rsa, inc_t csa,
    reals_ft x, inc_t incx,
    real_ft beta,
    reals_ft y, inc_t incy
) nogil:
    """Dispatch a matrix-vector multiply to ``bli_sgemv_ex``/``bli_dgemv_ex``.

    The routine is selected from the ``reals_ft`` specialisation:
    ``floats_t``/``float1d_t`` use the single-precision routine,
    ``doubles_t``/``double1d_t`` the double-precision one.  ``*1d_t``
    arguments are passed as the address of element 0.  The context is
    ``NULL`` and the runtime is the module-level ``rntm``.

    ``alpha`` and ``beta`` are copied into locals of both precisions because
    BLIS takes its scalars by pointer.
    """
    cdef float alpha_f = alpha
    cdef double alpha_d = alpha
    # beta must be copied from ``beta``; copying ``alpha`` here would make
    # the routine ignore the caller's beta argument.
    cdef float beta_f = beta
    cdef double beta_d = beta
    if reals_ft is floats_t:
        bli_sgemv_ex(
            <blis_trans_t>transa, <blis_conj_t>conjx,
            m, n,
            &alpha_f, a, rsa, csa,
            x, incx, &beta_f,
            y, incy, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dgemv_ex(
            <blis_trans_t>transa, <blis_conj_t>conjx,
            m, n,
            &alpha_d, a, rsa, csa,
            x, incx, &beta_d,
            y, incy, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_sgemv_ex(
            <blis_trans_t>transa, <blis_conj_t>conjx,
            m, n,
            &alpha_f, &a[0], rsa, csa,
            &x[0], incx, &beta_f,
            &y[0], incy, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dgemv_ex(
            <blis_trans_t>transa, <blis_conj_t>conjx,
            m, n,
            &alpha_d, &a[0], rsa, csa,
            &x[0], incx, &beta_d,
            &y[0], incy, NULL, &rntm)
    else:
        # Impossible --- panic?
        pass
cdef void axpyv(
    conj_t conjx,
    dim_t m,
    real_ft alpha,
    reals_ft x, inc_t incx,
    reals_ft y, inc_t incy
) nogil:
    """Dispatch to ``bli_saxpyv_ex`` or ``bli_daxpyv_ex`` for the data type.

    ``alpha`` is copied into a ``float`` for the single-precision routine and
    a ``double`` for the double-precision one.  ``*1d_t`` arguments are
    passed as the address of element 0.  The context is ``NULL`` and the
    runtime is the module-level ``rntm``.
    """
    cdef blis_conj_t cx = <blis_conj_t>conjx
    # BLIS takes the scalar by pointer, so keep one copy per precision.
    cdef float alpha32 = alpha
    cdef double alpha64 = alpha

    if reals_ft is floats_t:
        bli_saxpyv_ex(cx, m, &alpha32, x, incx, y, incy, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_daxpyv_ex(cx, m, &alpha64, x, incx, y, incy, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_saxpyv_ex(cx, m, &alpha32, &x[0], incx, &y[0], incy, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_daxpyv_ex(cx, m, &alpha64, &x[0], incx, &y[0], incy, NULL, &rntm)
    # Any other specialisation is a silent no-op.
cdef void scalv(
    conj_t conjalpha,
    dim_t m,
    real_ft alpha,
    reals_ft x, inc_t incx
) nogil:
    """Dispatch to ``bli_sscalv_ex`` or ``bli_dscalv_ex`` for the data type.

    ``alpha`` is copied into a ``float`` for the single-precision routine and
    a ``double`` for the double-precision one.  A ``*1d_t`` ``x`` is passed
    as the address of element 0.  The context is ``NULL`` and the runtime is
    the module-level ``rntm``.
    """
    cdef blis_conj_t calpha = <blis_conj_t>conjalpha
    # BLIS takes the scalar by pointer, so keep one copy per precision.
    cdef float alpha32 = alpha
    cdef double alpha64 = alpha

    if reals_ft is floats_t:
        bli_sscalv_ex(calpha, m, &alpha32, x, incx, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dscalv_ex(calpha, m, &alpha64, x, incx, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_sscalv_ex(calpha, m, &alpha32, &x[0], incx, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dscalv_ex(calpha, m, &alpha64, &x[0], incx, NULL, &rntm)
    # Any other specialisation is a silent no-op.
cdef double norm_L1(
    dim_t n,
    reals_ft x, inc_t incx
) nogil:
    """Return the result of ``bli_snorm1v_ex``/``bli_dnorm1v_ex`` on ``x``.

    Single-precision specialisations compute into a ``float`` which is then
    widened to ``double``.  A ``*1d_t`` ``x`` is passed as the address of
    element 0.  If no specialisation matches, 0 is returned.
    """
    cdef float result32 = 0
    cdef double result64 = 0

    if reals_ft is floats_t:
        bli_snorm1v_ex(n, x, incx, &result32, NULL, &rntm)
        result64 = result32
    elif reals_ft is doubles_t:
        bli_dnorm1v_ex(n, x, incx, &result64, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_snorm1v_ex(n, &x[0], incx, &result32, NULL, &rntm)
        result64 = result32
    elif reals_ft is double1d_t:
        bli_dnorm1v_ex(n, &x[0], incx, &result64, NULL, &rntm)
    return result64
cdef double norm_L2(
    dim_t n,
    reals_ft x, inc_t incx
) nogil:
    """Return the result of ``bli_snormfv_ex``/``bli_dnormfv_ex`` on ``x``.

    Single-precision specialisations compute into a ``float`` which is then
    widened to ``double``.  A ``*1d_t`` ``x`` is passed as the address of
    element 0.  If no specialisation matches, 0 is returned.
    """
    cdef float result32 = 0
    cdef double result64 = 0

    if reals_ft is floats_t:
        bli_snormfv_ex(n, x, incx, &result32, NULL, &rntm)
        result64 = result32
    elif reals_ft is doubles_t:
        bli_dnormfv_ex(n, x, incx, &result64, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_snormfv_ex(n, &x[0], incx, &result32, NULL, &rntm)
        result64 = result32
    elif reals_ft is double1d_t:
        bli_dnormfv_ex(n, &x[0], incx, &result64, NULL, &rntm)
    return result64
cdef double norm_inf(
    dim_t n,
    reals_ft x, inc_t incx
) nogil:
    """Return the result of ``bli_snormiv_ex``/``bli_dnormiv_ex`` on ``x``.

    Single-precision specialisations compute into a ``float`` which is then
    widened to ``double``.  A ``*1d_t`` ``x`` is passed as the address of
    element 0.  If no specialisation matches, 0 is returned.
    """
    cdef float result32 = 0
    cdef double result64 = 0

    if reals_ft is floats_t:
        bli_snormiv_ex(n, x, incx, &result32, NULL, &rntm)
        result64 = result32
    elif reals_ft is doubles_t:
        bli_dnormiv_ex(n, x, incx, &result64, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_snormiv_ex(n, &x[0], incx, &result32, NULL, &rntm)
        result64 = result32
    elif reals_ft is double1d_t:
        bli_dnormiv_ex(n, &x[0], incx, &result64, NULL, &rntm)
    return result64
cdef double dotv(
    conj_t conjx,
    conj_t conjy,
    dim_t m,
    reals_ft x,
    reals_ft y,
    inc_t incx,
    inc_t incy,
) nogil:
    """Return the result of ``bli_sdotv_ex``/``bli_ddotv_ex`` on ``x`` and ``y``.

    Note the parameter order: both vectors come first, then both increments.
    Single-precision specialisations compute into a ``float`` that is widened
    to ``double`` on return.  ``*1d_t`` arguments are passed as the address
    of element 0.  The context is ``NULL`` and the runtime is the module-level
    ``rntm``.

    The fallback branch raises ``ValueError``; because this function runs
    without the GIL, the GIL is acquired before raising, as ``randv`` and
    ``sumsqv`` do.
    """
    cdef double rho_d = 0.0
    cdef float rho_f = 0.0
    if reals_ft is floats_t:
        bli_sdotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, x, incx, y, incy, &rho_f, NULL, &rntm)
        return rho_f
    elif reals_ft is doubles_t:
        bli_ddotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, x, incx, y, incy, &rho_d, NULL, &rntm)
        return rho_d
    elif reals_ft is float1d_t:
        bli_sdotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, &x[0], incx, &y[0], incy,
                     &rho_f, NULL, &rntm)
        return rho_f
    elif reals_ft is double1d_t:
        bli_ddotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, &x[0], incx, &y[0], incy,
                     &rho_d, NULL, &rntm)
        return rho_d
    else:
        # Creating a Python exception requires the GIL.
        with gil:
            raise ValueError("Unhandled fused type")
cdef void randv(dim_t m, reals_ft x, inc_t incx) nogil:
    """Dispatch to ``bli_srandv_ex`` or ``bli_drandv_ex`` for the data type.

    A ``*1d_t`` ``x`` is passed as the address of element 0.  The context is
    ``NULL`` and the runtime is the module-level ``rntm``.

    All four specialisations are tested in ONE ``if``/``elif`` chain.  If the
    double-precision tests started a second ``if``, the float
    specialisations would fall through to the final ``else`` and raise
    ``ValueError`` after having already done their work.
    """
    if reals_ft is floats_t:
        bli_srandv_ex(m, x, incx, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_srandv_ex(m, &x[0], incx, NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_drandv_ex(m, x, incx, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_drandv_ex(m, &x[0], incx, NULL, &rntm)
    else:
        # Creating a Python exception requires the GIL.
        with gil:
            raise ValueError("Unhandled fused type")
cdef void sumsqv(dim_t m, reals_ft x, inc_t incx,
                 reals_ft scale, reals_ft sumsq) nogil:
    """Dispatch to ``bli_ssumsqv_ex`` or ``bli_dsumsqv_ex`` for the data type.

    ``scale`` and ``sumsq`` have the same fused type as ``x``; for the
    ``*1d_t`` specialisations the address of element 0 of each is passed.
    The context is ``NULL`` and the runtime is the module-level ``rntm``.

    All four specialisations are tested in ONE ``if``/``elif`` chain.  If the
    double-precision tests started a second ``if``, the float
    specialisations would fall through to the final ``else`` and raise
    ``ValueError`` after having already done their work.
    """
    if reals_ft is floats_t:
        # ``x`` is passed directly here, as in the ``doubles_t`` branch;
        # this is equivalent to ``&x[0]`` for a pointer.
        bli_ssumsqv_ex(m, x, incx, scale, sumsq, NULL, &rntm)
    elif reals_ft is float1d_t:
        bli_ssumsqv_ex(m, &x[0], incx, &scale[0], &sumsq[0], NULL, &rntm)
    elif reals_ft is doubles_t:
        bli_dsumsqv_ex(m, x, incx, scale, sumsq, NULL, &rntm)
    elif reals_ft is double1d_t:
        bli_dsumsqv_ex(m, &x[0], incx, &scale[0], &sumsq[0], NULL, &rntm)
    else:
        # Creating a Python exception requires the GIL.
        with gil:
            raise ValueError("Unhandled fused type")
def finalize():
    """Call ``bli_finalize()``; registered below to run at interpreter exit."""
    bli_finalize()


# Same effect as decorating ``finalize`` with ``@atexit.register``, which
# returns the function unchanged.
atexit.register(finalize)