|
|
- # cython: infer_types=True
- # cython: boundscheck=False
- # Copyright ExplsionAI GmbH, released under BSD.
-
- import atexit
-
-
- cdef extern from "blis.h" nogil:
- enum blis_err_t "err_t":
- pass
-
-
- cdef struct blis_cntx_t "cntx_t":
- pass
-
- cdef struct blis_rntm_t "rntm_s":
- pass
-
-
- ctypedef enum blis_trans_t "trans_t":
- BLIS_NO_TRANSPOSE
- BLIS_TRANSPOSE
- BLIS_CONJ_NO_TRANSPOSE
- BLIS_CONJ_TRANSPOSE
-
- ctypedef enum blis_conj_t "conj_t":
- BLIS_NO_CONJUGATE
- BLIS_CONJUGATE
-
- ctypedef enum blis_side_t "side_t":
- BLIS_LEFT
- BLIS_RIGHT
-
- ctypedef enum blis_uplo_t "uplo_t":
- BLIS_LOWER
- BLIS_UPPER
- BLIS_DENSE
-
- ctypedef enum blis_diag_t "diag_t":
- BLIS_NONUNIT_DIAG
- BLIS_UNIT_DIAG
-
- char* bli_info_get_int_type_size_str()
-
- blis_err_t bli_init()
- blis_err_t bli_finalize()
-
- blis_err_t bli_rntm_init(blis_rntm_t* rntm);
-
- # BLAS level 3 routines
- void bli_dgemm_ex(
- blis_trans_t transa,
- blis_trans_t transb,
- dim_t m,
- dim_t n,
- dim_t k,
- double* alpha,
- double* a, inc_t rsa, inc_t csa,
- double* b, inc_t rsb, inc_t csb,
- double* beta,
- double* c, inc_t rsc, inc_t csc,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
- # BLAS level 3 routines
- void bli_sgemm_ex(
- blis_trans_t transa,
- blis_trans_t transb,
- dim_t m,
- dim_t n,
- dim_t k,
- float* alpha,
- float* a, inc_t rsa, inc_t csa,
- float* b, inc_t rsb, inc_t csb,
- float* beta,
- float* c, inc_t rsc, inc_t csc,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_dger_ex(
- blis_conj_t conjx,
- blis_conj_t conjy,
- dim_t m,
- dim_t n,
- double* alpha,
- double* x, inc_t incx,
- double* y, inc_t incy,
- double* a, inc_t rsa, inc_t csa,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_sger_ex(
- blis_conj_t conjx,
- blis_conj_t conjy,
- dim_t m,
- dim_t n,
- float* alpha,
- float* x, inc_t incx,
- float* y, inc_t incy,
- float* a, inc_t rsa, inc_t csa,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_dgemv_ex(
- blis_trans_t transa,
- blis_conj_t conjx,
- dim_t m,
- dim_t n,
- double* alpha,
- double* a, inc_t rsa, inc_t csa,
- double* x, inc_t incx,
- double* beta,
- double* y, inc_t incy,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_sgemv_ex(
- blis_trans_t transa,
- blis_conj_t conjx,
- dim_t m,
- dim_t n,
- float* alpha,
- float* a, inc_t rsa, inc_t csa,
- float* x, inc_t incx,
- float* beta,
- float* y, inc_t incy,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_daxpyv_ex(
- blis_conj_t conjx,
- dim_t m,
- double* alpha,
- double* x, inc_t incx,
- double* y, inc_t incy,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_saxpyv_ex(
- blis_conj_t conjx,
- dim_t m,
- float* alpha,
- float* x, inc_t incx,
- float* y, inc_t incy,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_dscalv_ex(
- blis_conj_t conjalpha,
- dim_t m,
- double* alpha,
- double* x, inc_t incx,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_sscalv_ex(
- blis_conj_t conjalpha,
- dim_t m,
- float* alpha,
- float* x, inc_t incx,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_ddotv_ex(
- blis_conj_t conjx,
- blis_conj_t conjy,
- dim_t m,
- double* x, inc_t incx,
- double* y, inc_t incy,
- double* rho,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_sdotv_ex(
- blis_conj_t conjx,
- blis_conj_t conjy,
- dim_t m,
- float* x, inc_t incx,
- float* y, inc_t incy,
- float* rho,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_snorm1v_ex(
- dim_t n,
- float* x, inc_t incx,
- float* norm,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_dnorm1v_ex(
- dim_t n,
- double* x, inc_t incx,
- double* norm,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_snormfv_ex(
- dim_t n,
- float* x, inc_t incx,
- float* norm,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_dnormfv_ex(
- dim_t n,
- double* x, inc_t incx,
- double* norm,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_snormiv_ex(
- dim_t n,
- float* x, inc_t incx,
- float* norm,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_dnormiv_ex(
- dim_t n,
- double* x, inc_t incx,
- double* norm,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_srandv_ex(
- dim_t m,
- float* x, inc_t incx,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_drandv_ex(
- dim_t m,
- double* x, inc_t incx,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- )
-
- void bli_ssumsqv_ex(
- dim_t m,
- float* x, inc_t incx,
- float* scale,
- float* sumsq,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- ) nogil
-
- void bli_dsumsqv_ex(
- dim_t m,
- double* x, inc_t incx,
- double* scale,
- double* sumsq,
- blis_cntx_t* cntx,
- blis_rntm_t* rntm,
- ) nogil
-
-
-
- bli_init()
- cdef blis_rntm_t rntm;
-
- def init():
- bli_init()
- bli_rntm_init(&rntm);
- assert BLIS_NO_TRANSPOSE == <blis_trans_t>NO_TRANSPOSE
- assert BLIS_TRANSPOSE == <blis_trans_t>TRANSPOSE
- assert BLIS_CONJ_NO_TRANSPOSE == <blis_trans_t>CONJ_NO_TRANSPOSE
- assert BLIS_CONJ_TRANSPOSE == <blis_trans_t>CONJ_TRANSPOSE
- assert BLIS_NO_CONJUGATE == <blis_conj_t>NO_CONJUGATE
- assert BLIS_CONJUGATE == <blis_conj_t>CONJUGATE
- assert BLIS_LEFT == <blis_side_t>LEFT
- assert BLIS_RIGHT == <blis_side_t>RIGHT
- assert BLIS_LOWER == <blis_uplo_t>LOWER
- assert BLIS_UPPER == <blis_uplo_t>UPPER
- assert BLIS_DENSE == <blis_uplo_t>DENSE
- assert BLIS_NONUNIT_DIAG == <blis_diag_t>NONUNIT_DIAG
- assert BLIS_UNIT_DIAG == <blis_diag_t>UNIT_DIAG
-
-
- def get_int_type_size():
- cdef char* int_size = bli_info_get_int_type_size_str()
- return '%d' % int_size[0]
-
-
- # BLAS level 3 routines
- cdef void gemm(
- trans_t trans_a,
- trans_t trans_b,
- dim_t m,
- dim_t n,
- dim_t k,
- double alpha,
- reals_ft a, inc_t rsa, inc_t csa,
- reals_ft b, inc_t rsb, inc_t csb,
- double beta,
- reals_ft c, inc_t rsc, inc_t csc
- ) nogil:
- cdef float alpha_f = alpha
- cdef float beta_f = beta
- cdef double alpha_d = alpha
- cdef double beta_d = beta
- if reals_ft is floats_t:
- bli_sgemm_ex(
- <blis_trans_t>trans_a, <blis_trans_t>trans_b,
- m, n, k,
- &alpha_f, a, rsa, csa, b, rsb, csb, &beta_f, c, rsc, csc, NULL, &rntm)
- elif reals_ft is doubles_t:
- bli_dgemm_ex(
- <blis_trans_t>trans_a, <blis_trans_t>trans_b,
- m, n, k,
- &alpha_d, a, rsa, csa, b, rsb, csb, &beta_d, c, rsc, csc, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_sgemm_ex(
- <blis_trans_t>trans_a, <blis_trans_t>trans_b,
- m, n, k,
- &alpha_f, &a[0], rsa, csa, &b[0], rsb, csb, &beta_f, &c[0],
- rsc, csc, NULL, &rntm)
- elif reals_ft is double1d_t:
- bli_dgemm_ex(
- <blis_trans_t>trans_a, <blis_trans_t>trans_b,
- m, n, k,
- &alpha_d, &a[0], rsa, csa, &b[0], rsb, csb, &beta_d, &c[0],
- rsc, csc, NULL, &rntm)
- else:
- # Impossible --- panic?
- pass
-
-
- cdef void ger(
- conj_t conjx,
- conj_t conjy,
- dim_t m,
- dim_t n,
- double alpha,
- reals_ft x, inc_t incx,
- reals_ft y, inc_t incy,
- reals_ft a, inc_t rsa, inc_t csa
- ) nogil:
- cdef float alpha_f = alpha
- cdef double alpha_d = alpha
- if reals_ft is floats_t:
- bli_sger_ex(
- <blis_conj_t>conjx, <blis_conj_t>conjy,
- m, n,
- &alpha_f,
- x, incx, y, incy, a, rsa, csa, NULL, &rntm)
- elif reals_ft is doubles_t:
- bli_dger_ex(
- <blis_conj_t>conjx, <blis_conj_t>conjy,
- m, n,
- &alpha_d,
- x, incx, y, incy, a, rsa, csa, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_sger_ex(
- <blis_conj_t>conjx, <blis_conj_t>conjy,
- m, n,
- &alpha_f,
- &x[0], incx, &y[0], incy, &a[0], rsa, csa, NULL, &rntm)
- elif reals_ft is double1d_t:
- bli_dger_ex(
- <blis_conj_t>conjx, <blis_conj_t>conjy,
- m, n,
- &alpha_d,
- &x[0], incx, &y[0], incy, &a[0], rsa, csa, NULL, &rntm)
- else:
- # Impossible --- panic?
- pass
-
-
- cdef void gemv(
- trans_t transa,
- conj_t conjx,
- dim_t m,
- dim_t n,
- real_ft alpha,
- reals_ft a, inc_t rsa, inc_t csa,
- reals_ft x, inc_t incx,
- real_ft beta,
- reals_ft y, inc_t incy
- ) nogil:
- cdef float alpha_f = alpha
- cdef double alpha_d = alpha
- cdef float beta_f = alpha
- cdef double beta_d = alpha
- if reals_ft is floats_t:
- bli_sgemv_ex(
- <blis_trans_t>transa, <blis_conj_t>conjx,
- m, n,
- &alpha_f, a, rsa, csa,
- x, incx, &beta_f,
- y, incy, NULL, &rntm)
- elif reals_ft is doubles_t:
- bli_dgemv_ex(
- <blis_trans_t>transa, <blis_conj_t>conjx,
- m, n,
- &alpha_d, a, rsa, csa,
- x, incx, &beta_d,
- y, incy, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_sgemv_ex(
- <blis_trans_t>transa, <blis_conj_t>conjx,
- m, n,
- &alpha_f, &a[0], rsa, csa,
- &x[0], incx, &beta_f,
- &y[0], incy, NULL, &rntm)
- elif reals_ft is double1d_t:
- bli_dgemv_ex(
- <blis_trans_t>transa, <blis_conj_t>conjx,
- m, n,
- &alpha_d, &a[0], rsa, csa,
- &x[0], incx, &beta_d,
- &y[0], incy, NULL, &rntm)
- else:
- # Impossible --- panic?
- pass
-
-
- cdef void axpyv(
- conj_t conjx,
- dim_t m,
- real_ft alpha,
- reals_ft x, inc_t incx,
- reals_ft y, inc_t incy
- ) nogil:
- cdef float alpha_f = alpha
- cdef double alpha_d = alpha
- if reals_ft is floats_t:
- bli_saxpyv_ex(<blis_conj_t>conjx, m, &alpha_f, x, incx, y, incy, NULL, &rntm)
- elif reals_ft is doubles_t:
- bli_daxpyv_ex(<blis_conj_t>conjx, m, &alpha_d, x, incx, y, incy, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_saxpyv_ex(<blis_conj_t>conjx, m, &alpha_f, &x[0], incx, &y[0], incy, NULL, &rntm)
- elif reals_ft is double1d_t:
- bli_daxpyv_ex(<blis_conj_t>conjx, m, &alpha_d, &x[0], incx, &y[0], incy, NULL, &rntm)
- else:
- # Impossible --- panic?
- pass
-
-
- cdef void scalv(
- conj_t conjalpha,
- dim_t m,
- real_ft alpha,
- reals_ft x, inc_t incx
- ) nogil:
- cdef float alpha_f = alpha
- cdef double alpha_d = alpha
- if reals_ft is floats_t:
- bli_sscalv_ex(<blis_conj_t>conjalpha, m, &alpha_f, x, incx, NULL, &rntm)
- elif reals_ft is doubles_t:
- bli_dscalv_ex(<blis_conj_t>conjalpha, m, &alpha_d, x, incx, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_sscalv_ex(<blis_conj_t>conjalpha, m, &alpha_f, &x[0], incx, NULL, &rntm)
- elif reals_ft is double1d_t:
- bli_dscalv_ex(<blis_conj_t>conjalpha, m, &alpha_d, &x[0], incx, NULL, &rntm)
- else:
- # Impossible --- panic?
- pass
-
-
- cdef double norm_L1(
- dim_t n,
- reals_ft x, inc_t incx
- ) nogil:
- cdef double dnorm = 0
- cdef float snorm = 0
- if reals_ft is floats_t:
- bli_snorm1v_ex(n, x, incx, &snorm, NULL, &rntm)
- dnorm = snorm
- elif reals_ft is doubles_t:
- bli_dnorm1v_ex(n, x, incx, &dnorm, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_snorm1v_ex(n, &x[0], incx, &snorm, NULL, &rntm)
- dnorm = snorm
- elif reals_ft is double1d_t:
- bli_dnorm1v_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
- else:
- # Impossible --- panic?
- pass
- return dnorm
-
-
- cdef double norm_L2(
- dim_t n,
- reals_ft x, inc_t incx
- ) nogil:
- cdef double dnorm = 0
- cdef float snorm = 0
- if reals_ft is floats_t:
- bli_snormfv_ex(n, x, incx, &snorm, NULL, &rntm)
- dnorm = snorm
- elif reals_ft is doubles_t:
- bli_dnormfv_ex(n, x, incx, &dnorm, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_snormfv_ex(n, &x[0], incx, &snorm, NULL, &rntm)
- dnorm = snorm
- elif reals_ft is double1d_t:
- bli_dnormfv_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
- else:
- # Impossible --- panic?
- pass
- return dnorm
-
-
- cdef double norm_inf(
- dim_t n,
- reals_ft x, inc_t incx
- ) nogil:
- cdef double dnorm = 0
- cdef float snorm = 0
- if reals_ft is floats_t:
- bli_snormiv_ex(n, x, incx, &snorm, NULL, &rntm)
- dnorm = snorm
- elif reals_ft is doubles_t:
- bli_dnormiv_ex(n, x, incx, &dnorm, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_snormiv_ex(n, &x[0], incx, &snorm, NULL, &rntm)
- dnorm = snorm
- elif reals_ft is double1d_t:
- bli_dnormiv_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
- else:
- # Impossible --- panic?
- pass
- return dnorm
-
-
- cdef double dotv(
- conj_t conjx,
- conj_t conjy,
- dim_t m,
- reals_ft x,
- reals_ft y,
- inc_t incx,
- inc_t incy,
- ) nogil:
- cdef double rho_d = 0.0
- cdef float rho_f = 0.0
- if reals_ft is floats_t:
- bli_sdotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, x, incx, y, incy, &rho_f, NULL, &rntm)
- return rho_f
- elif reals_ft is doubles_t:
- bli_ddotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, x, incx, y, incy, &rho_d, NULL, &rntm)
- return rho_d
- elif reals_ft is float1d_t:
- bli_sdotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, &x[0], incx, &y[0], incy,
- &rho_f, NULL, &rntm)
- return rho_f
- elif reals_ft is double1d_t:
- bli_ddotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, &x[0], incx, &y[0], incy,
- &rho_d, NULL, &rntm)
- return rho_d
- else:
- raise ValueError("Unhandled fused type")
-
-
- cdef void randv(dim_t m, reals_ft x, inc_t incx) nogil:
- if reals_ft is floats_t:
- bli_srandv_ex(m, x, incx, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_srandv_ex(m, &x[0], incx, NULL, &rntm)
- if reals_ft is doubles_t:
- bli_drandv_ex(m, x, incx, NULL, &rntm)
- elif reals_ft is double1d_t:
- bli_drandv_ex(m, &x[0], incx, NULL, &rntm)
- else:
- with gil:
- raise ValueError("Unhandled fused type")
-
-
- cdef void sumsqv(dim_t m, reals_ft x, inc_t incx,
- reals_ft scale, reals_ft sumsq) nogil:
- if reals_ft is floats_t:
- bli_ssumsqv_ex(m, &x[0], incx, scale, sumsq, NULL, &rntm)
- elif reals_ft is float1d_t:
- bli_ssumsqv_ex(m, &x[0], incx, &scale[0], &sumsq[0], NULL, &rntm)
- if reals_ft is doubles_t:
- bli_dsumsqv_ex(m, x, incx, scale, sumsq, NULL, &rntm)
- elif reals_ft is double1d_t:
- bli_dsumsqv_ex(m, &x[0], incx, &scale[0], &sumsq[0], NULL, &rntm)
- else:
- with gil:
- raise ValueError("Unhandled fused type")
-
-
- @atexit.register
- def finalize():
- bli_finalize()
|