You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

605 lines
15 KiB

4 years ago
  1. # cython: infer_types=True
  2. # cython: boundscheck=False
# Copyright ExplosionAI GmbH, released under BSD.
  4. import atexit
# Declarations for the subset of the BLIS C API wrapped by this module.
# The whole block is declared nogil, so the wrappers below call these
# functions without holding the GIL.
#
# All routines used here are the "_ex" variants, which take an explicit
# context (cntx) and runtime (rntm) as their last two arguments; every wrapper
# in this module passes NULL for cntx and &rntm (the module-level runtime
# object) for rntm.
# Following the BLIS typed-API naming convention, rs*/cs* are row/column
# strides and inc* are vector increments.
cdef extern from "blis.h" nogil:
    # Declared without fields (`pass`), so Cython treats them as opaque.
    enum blis_err_t "err_t":
        pass
    cdef struct blis_cntx_t "cntx_t":
        pass
    cdef struct blis_rntm_t "rntm_s":
        pass

    # BLIS parameter enums. The wrappers cast their own trans_t/conj_t/...
    # arguments to these types, so the numeric values must agree; init()
    # asserts that they do.
    ctypedef enum blis_trans_t "trans_t":
        BLIS_NO_TRANSPOSE
        BLIS_TRANSPOSE
        BLIS_CONJ_NO_TRANSPOSE
        BLIS_CONJ_TRANSPOSE
    ctypedef enum blis_conj_t "conj_t":
        BLIS_NO_CONJUGATE
        BLIS_CONJUGATE
    ctypedef enum blis_side_t "side_t":
        BLIS_LEFT
        BLIS_RIGHT
    ctypedef enum blis_uplo_t "uplo_t":
        BLIS_LOWER
        BLIS_UPPER
        BLIS_DENSE
    ctypedef enum blis_diag_t "diag_t":
        BLIS_NONUNIT_DIAG
        BLIS_UNIT_DIAG

    # Library information and lifecycle.
    char* bli_info_get_int_type_size_str()
    # NOTE(review): the three functions below are declared as returning
    # err_t; confirm against the blis.h in use. Their results are never read
    # in this module.
    blis_err_t bli_init()
    blis_err_t bli_finalize()
    # Initialises the runtime object that is passed to every *_ex call.
    blis_err_t bli_rntm_init(blis_rntm_t* rntm);

    # BLAS level 3 routines: general matrix-matrix multiply.
    # The "d" variants take double, the "s" variants take float.
    void bli_dgemm_ex(
        blis_trans_t transa,
        blis_trans_t transb,
        dim_t m,
        dim_t n,
        dim_t k,
        double* alpha,
        double* a, inc_t rsa, inc_t csa,
        double* b, inc_t rsb, inc_t csb,
        double* beta,
        double* c, inc_t rsc, inc_t csc,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_sgemm_ex(
        blis_trans_t transa,
        blis_trans_t transb,
        dim_t m,
        dim_t n,
        dim_t k,
        float* alpha,
        float* a, inc_t rsa, inc_t csa,
        float* b, inc_t rsb, inc_t csb,
        float* beta,
        float* c, inc_t rsc, inc_t csc,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # BLAS level 2 routines: rank-1 update (ger).
    void bli_dger_ex(
        blis_conj_t conjx,
        blis_conj_t conjy,
        dim_t m,
        dim_t n,
        double* alpha,
        double* x, inc_t incx,
        double* y, inc_t incy,
        double* a, inc_t rsa, inc_t csa,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_sger_ex(
        blis_conj_t conjx,
        blis_conj_t conjy,
        dim_t m,
        dim_t n,
        float* alpha,
        float* x, inc_t incx,
        float* y, inc_t incy,
        float* a, inc_t rsa, inc_t csa,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # BLAS level 2 routines: general matrix-vector multiply (gemv).
    void bli_dgemv_ex(
        blis_trans_t transa,
        blis_conj_t conjx,
        dim_t m,
        dim_t n,
        double* alpha,
        double* a, inc_t rsa, inc_t csa,
        double* x, inc_t incx,
        double* beta,
        double* y, inc_t incy,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_sgemv_ex(
        blis_trans_t transa,
        blis_conj_t conjx,
        dim_t m,
        dim_t n,
        float* alpha,
        float* a, inc_t rsa, inc_t csa,
        float* x, inc_t incx,
        float* beta,
        float* y, inc_t incy,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # Level-1v vector operations: axpyv, scalv, dotv.
    void bli_daxpyv_ex(
        blis_conj_t conjx,
        dim_t m,
        double* alpha,
        double* x, inc_t incx,
        double* y, inc_t incy,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_saxpyv_ex(
        blis_conj_t conjx,
        dim_t m,
        float* alpha,
        float* x, inc_t incx,
        float* y, inc_t incy,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_dscalv_ex(
        blis_conj_t conjalpha,
        dim_t m,
        double* alpha,
        double* x, inc_t incx,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_sscalv_ex(
        blis_conj_t conjalpha,
        dim_t m,
        float* alpha,
        float* x, inc_t incx,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    # The dot product is written through the rho out-pointer.
    void bli_ddotv_ex(
        blis_conj_t conjx,
        blis_conj_t conjy,
        dim_t m,
        double* x, inc_t incx,
        double* y, inc_t incy,
        double* rho,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_sdotv_ex(
        blis_conj_t conjx,
        blis_conj_t conjy,
        dim_t m,
        float* x, inc_t incx,
        float* y, inc_t incy,
        float* rho,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # Vector norms, each written through the norm out-pointer:
    # norm1v (wrapped by norm_L1), normfv (norm_L2), normiv (norm_inf).
    void bli_snorm1v_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_dnorm1v_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_snormfv_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_dnormfv_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_snormiv_ex(
        dim_t n,
        float* x, inc_t incx,
        float* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_dnormiv_ex(
        dim_t n,
        double* x, inc_t incx,
        double* norm,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # Random-vector generation (randv); writes into x.
    void bli_srandv_ex(
        dim_t m,
        float* x, inc_t incx,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )
    void bli_drandv_ex(
        dim_t m,
        double* x, inc_t incx,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    )

    # Scaled sum of squares (sumsqv); scale and sumsq are passed by pointer.
    # The explicit `nogil` repeats the block-level nogil and is harmless.
    void bli_ssumsqv_ex(
        dim_t m,
        float* x, inc_t incx,
        float* scale,
        float* sumsq,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    ) nogil
    void bli_dsumsqv_ex(
        dim_t m,
        double* x, inc_t incx,
        double* scale,
        double* sumsq,
        blis_cntx_t* cntx,
        blis_rntm_t* rntm,
    ) nogil
  238. bli_init()
  239. cdef blis_rntm_t rntm;
  240. def init():
  241. bli_init()
  242. bli_rntm_init(&rntm);
  243. assert BLIS_NO_TRANSPOSE == <blis_trans_t>NO_TRANSPOSE
  244. assert BLIS_TRANSPOSE == <blis_trans_t>TRANSPOSE
  245. assert BLIS_CONJ_NO_TRANSPOSE == <blis_trans_t>CONJ_NO_TRANSPOSE
  246. assert BLIS_CONJ_TRANSPOSE == <blis_trans_t>CONJ_TRANSPOSE
  247. assert BLIS_NO_CONJUGATE == <blis_conj_t>NO_CONJUGATE
  248. assert BLIS_CONJUGATE == <blis_conj_t>CONJUGATE
  249. assert BLIS_LEFT == <blis_side_t>LEFT
  250. assert BLIS_RIGHT == <blis_side_t>RIGHT
  251. assert BLIS_LOWER == <blis_uplo_t>LOWER
  252. assert BLIS_UPPER == <blis_uplo_t>UPPER
  253. assert BLIS_DENSE == <blis_uplo_t>DENSE
  254. assert BLIS_NONUNIT_DIAG == <blis_diag_t>NONUNIT_DIAG
  255. assert BLIS_UNIT_DIAG == <blis_diag_t>UNIT_DIAG
  256. def get_int_type_size():
  257. cdef char* int_size = bli_info_get_int_type_size_str()
  258. return '%d' % int_size[0]
  259. # BLAS level 3 routines
  260. cdef void gemm(
  261. trans_t trans_a,
  262. trans_t trans_b,
  263. dim_t m,
  264. dim_t n,
  265. dim_t k,
  266. double alpha,
  267. reals_ft a, inc_t rsa, inc_t csa,
  268. reals_ft b, inc_t rsb, inc_t csb,
  269. double beta,
  270. reals_ft c, inc_t rsc, inc_t csc
  271. ) nogil:
  272. cdef float alpha_f = alpha
  273. cdef float beta_f = beta
  274. cdef double alpha_d = alpha
  275. cdef double beta_d = beta
  276. if reals_ft is floats_t:
  277. bli_sgemm_ex(
  278. <blis_trans_t>trans_a, <blis_trans_t>trans_b,
  279. m, n, k,
  280. &alpha_f, a, rsa, csa, b, rsb, csb, &beta_f, c, rsc, csc, NULL, &rntm)
  281. elif reals_ft is doubles_t:
  282. bli_dgemm_ex(
  283. <blis_trans_t>trans_a, <blis_trans_t>trans_b,
  284. m, n, k,
  285. &alpha_d, a, rsa, csa, b, rsb, csb, &beta_d, c, rsc, csc, NULL, &rntm)
  286. elif reals_ft is float1d_t:
  287. bli_sgemm_ex(
  288. <blis_trans_t>trans_a, <blis_trans_t>trans_b,
  289. m, n, k,
  290. &alpha_f, &a[0], rsa, csa, &b[0], rsb, csb, &beta_f, &c[0],
  291. rsc, csc, NULL, &rntm)
  292. elif reals_ft is double1d_t:
  293. bli_dgemm_ex(
  294. <blis_trans_t>trans_a, <blis_trans_t>trans_b,
  295. m, n, k,
  296. &alpha_d, &a[0], rsa, csa, &b[0], rsb, csb, &beta_d, &c[0],
  297. rsc, csc, NULL, &rntm)
  298. else:
  299. # Impossible --- panic?
  300. pass
  301. cdef void ger(
  302. conj_t conjx,
  303. conj_t conjy,
  304. dim_t m,
  305. dim_t n,
  306. double alpha,
  307. reals_ft x, inc_t incx,
  308. reals_ft y, inc_t incy,
  309. reals_ft a, inc_t rsa, inc_t csa
  310. ) nogil:
  311. cdef float alpha_f = alpha
  312. cdef double alpha_d = alpha
  313. if reals_ft is floats_t:
  314. bli_sger_ex(
  315. <blis_conj_t>conjx, <blis_conj_t>conjy,
  316. m, n,
  317. &alpha_f,
  318. x, incx, y, incy, a, rsa, csa, NULL, &rntm)
  319. elif reals_ft is doubles_t:
  320. bli_dger_ex(
  321. <blis_conj_t>conjx, <blis_conj_t>conjy,
  322. m, n,
  323. &alpha_d,
  324. x, incx, y, incy, a, rsa, csa, NULL, &rntm)
  325. elif reals_ft is float1d_t:
  326. bli_sger_ex(
  327. <blis_conj_t>conjx, <blis_conj_t>conjy,
  328. m, n,
  329. &alpha_f,
  330. &x[0], incx, &y[0], incy, &a[0], rsa, csa, NULL, &rntm)
  331. elif reals_ft is double1d_t:
  332. bli_dger_ex(
  333. <blis_conj_t>conjx, <blis_conj_t>conjy,
  334. m, n,
  335. &alpha_d,
  336. &x[0], incx, &y[0], incy, &a[0], rsa, csa, NULL, &rntm)
  337. else:
  338. # Impossible --- panic?
  339. pass
  340. cdef void gemv(
  341. trans_t transa,
  342. conj_t conjx,
  343. dim_t m,
  344. dim_t n,
  345. real_ft alpha,
  346. reals_ft a, inc_t rsa, inc_t csa,
  347. reals_ft x, inc_t incx,
  348. real_ft beta,
  349. reals_ft y, inc_t incy
  350. ) nogil:
  351. cdef float alpha_f = alpha
  352. cdef double alpha_d = alpha
  353. cdef float beta_f = alpha
  354. cdef double beta_d = alpha
  355. if reals_ft is floats_t:
  356. bli_sgemv_ex(
  357. <blis_trans_t>transa, <blis_conj_t>conjx,
  358. m, n,
  359. &alpha_f, a, rsa, csa,
  360. x, incx, &beta_f,
  361. y, incy, NULL, &rntm)
  362. elif reals_ft is doubles_t:
  363. bli_dgemv_ex(
  364. <blis_trans_t>transa, <blis_conj_t>conjx,
  365. m, n,
  366. &alpha_d, a, rsa, csa,
  367. x, incx, &beta_d,
  368. y, incy, NULL, &rntm)
  369. elif reals_ft is float1d_t:
  370. bli_sgemv_ex(
  371. <blis_trans_t>transa, <blis_conj_t>conjx,
  372. m, n,
  373. &alpha_f, &a[0], rsa, csa,
  374. &x[0], incx, &beta_f,
  375. &y[0], incy, NULL, &rntm)
  376. elif reals_ft is double1d_t:
  377. bli_dgemv_ex(
  378. <blis_trans_t>transa, <blis_conj_t>conjx,
  379. m, n,
  380. &alpha_d, &a[0], rsa, csa,
  381. &x[0], incx, &beta_d,
  382. &y[0], incy, NULL, &rntm)
  383. else:
  384. # Impossible --- panic?
  385. pass
  386. cdef void axpyv(
  387. conj_t conjx,
  388. dim_t m,
  389. real_ft alpha,
  390. reals_ft x, inc_t incx,
  391. reals_ft y, inc_t incy
  392. ) nogil:
  393. cdef float alpha_f = alpha
  394. cdef double alpha_d = alpha
  395. if reals_ft is floats_t:
  396. bli_saxpyv_ex(<blis_conj_t>conjx, m, &alpha_f, x, incx, y, incy, NULL, &rntm)
  397. elif reals_ft is doubles_t:
  398. bli_daxpyv_ex(<blis_conj_t>conjx, m, &alpha_d, x, incx, y, incy, NULL, &rntm)
  399. elif reals_ft is float1d_t:
  400. bli_saxpyv_ex(<blis_conj_t>conjx, m, &alpha_f, &x[0], incx, &y[0], incy, NULL, &rntm)
  401. elif reals_ft is double1d_t:
  402. bli_daxpyv_ex(<blis_conj_t>conjx, m, &alpha_d, &x[0], incx, &y[0], incy, NULL, &rntm)
  403. else:
  404. # Impossible --- panic?
  405. pass
  406. cdef void scalv(
  407. conj_t conjalpha,
  408. dim_t m,
  409. real_ft alpha,
  410. reals_ft x, inc_t incx
  411. ) nogil:
  412. cdef float alpha_f = alpha
  413. cdef double alpha_d = alpha
  414. if reals_ft is floats_t:
  415. bli_sscalv_ex(<blis_conj_t>conjalpha, m, &alpha_f, x, incx, NULL, &rntm)
  416. elif reals_ft is doubles_t:
  417. bli_dscalv_ex(<blis_conj_t>conjalpha, m, &alpha_d, x, incx, NULL, &rntm)
  418. elif reals_ft is float1d_t:
  419. bli_sscalv_ex(<blis_conj_t>conjalpha, m, &alpha_f, &x[0], incx, NULL, &rntm)
  420. elif reals_ft is double1d_t:
  421. bli_dscalv_ex(<blis_conj_t>conjalpha, m, &alpha_d, &x[0], incx, NULL, &rntm)
  422. else:
  423. # Impossible --- panic?
  424. pass
  425. cdef double norm_L1(
  426. dim_t n,
  427. reals_ft x, inc_t incx
  428. ) nogil:
  429. cdef double dnorm = 0
  430. cdef float snorm = 0
  431. if reals_ft is floats_t:
  432. bli_snorm1v_ex(n, x, incx, &snorm, NULL, &rntm)
  433. dnorm = snorm
  434. elif reals_ft is doubles_t:
  435. bli_dnorm1v_ex(n, x, incx, &dnorm, NULL, &rntm)
  436. elif reals_ft is float1d_t:
  437. bli_snorm1v_ex(n, &x[0], incx, &snorm, NULL, &rntm)
  438. dnorm = snorm
  439. elif reals_ft is double1d_t:
  440. bli_dnorm1v_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
  441. else:
  442. # Impossible --- panic?
  443. pass
  444. return dnorm
  445. cdef double norm_L2(
  446. dim_t n,
  447. reals_ft x, inc_t incx
  448. ) nogil:
  449. cdef double dnorm = 0
  450. cdef float snorm = 0
  451. if reals_ft is floats_t:
  452. bli_snormfv_ex(n, x, incx, &snorm, NULL, &rntm)
  453. dnorm = snorm
  454. elif reals_ft is doubles_t:
  455. bli_dnormfv_ex(n, x, incx, &dnorm, NULL, &rntm)
  456. elif reals_ft is float1d_t:
  457. bli_snormfv_ex(n, &x[0], incx, &snorm, NULL, &rntm)
  458. dnorm = snorm
  459. elif reals_ft is double1d_t:
  460. bli_dnormfv_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
  461. else:
  462. # Impossible --- panic?
  463. pass
  464. return dnorm
  465. cdef double norm_inf(
  466. dim_t n,
  467. reals_ft x, inc_t incx
  468. ) nogil:
  469. cdef double dnorm = 0
  470. cdef float snorm = 0
  471. if reals_ft is floats_t:
  472. bli_snormiv_ex(n, x, incx, &snorm, NULL, &rntm)
  473. dnorm = snorm
  474. elif reals_ft is doubles_t:
  475. bli_dnormiv_ex(n, x, incx, &dnorm, NULL, &rntm)
  476. elif reals_ft is float1d_t:
  477. bli_snormiv_ex(n, &x[0], incx, &snorm, NULL, &rntm)
  478. dnorm = snorm
  479. elif reals_ft is double1d_t:
  480. bli_dnormiv_ex(n, &x[0], incx, &dnorm, NULL, &rntm)
  481. else:
  482. # Impossible --- panic?
  483. pass
  484. return dnorm
  485. cdef double dotv(
  486. conj_t conjx,
  487. conj_t conjy,
  488. dim_t m,
  489. reals_ft x,
  490. reals_ft y,
  491. inc_t incx,
  492. inc_t incy,
  493. ) nogil:
  494. cdef double rho_d = 0.0
  495. cdef float rho_f = 0.0
  496. if reals_ft is floats_t:
  497. bli_sdotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, x, incx, y, incy, &rho_f, NULL, &rntm)
  498. return rho_f
  499. elif reals_ft is doubles_t:
  500. bli_ddotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, x, incx, y, incy, &rho_d, NULL, &rntm)
  501. return rho_d
  502. elif reals_ft is float1d_t:
  503. bli_sdotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, &x[0], incx, &y[0], incy,
  504. &rho_f, NULL, &rntm)
  505. return rho_f
  506. elif reals_ft is double1d_t:
  507. bli_ddotv_ex(<blis_conj_t>conjx, <blis_conj_t>conjy, m, &x[0], incx, &y[0], incy,
  508. &rho_d, NULL, &rntm)
  509. return rho_d
  510. else:
  511. raise ValueError("Unhandled fused type")
  512. cdef void randv(dim_t m, reals_ft x, inc_t incx) nogil:
  513. if reals_ft is floats_t:
  514. bli_srandv_ex(m, x, incx, NULL, &rntm)
  515. elif reals_ft is float1d_t:
  516. bli_srandv_ex(m, &x[0], incx, NULL, &rntm)
  517. if reals_ft is doubles_t:
  518. bli_drandv_ex(m, x, incx, NULL, &rntm)
  519. elif reals_ft is double1d_t:
  520. bli_drandv_ex(m, &x[0], incx, NULL, &rntm)
  521. else:
  522. with gil:
  523. raise ValueError("Unhandled fused type")
  524. cdef void sumsqv(dim_t m, reals_ft x, inc_t incx,
  525. reals_ft scale, reals_ft sumsq) nogil:
  526. if reals_ft is floats_t:
  527. bli_ssumsqv_ex(m, &x[0], incx, scale, sumsq, NULL, &rntm)
  528. elif reals_ft is float1d_t:
  529. bli_ssumsqv_ex(m, &x[0], incx, &scale[0], &sumsq[0], NULL, &rntm)
  530. if reals_ft is doubles_t:
  531. bli_dsumsqv_ex(m, x, incx, scale, sumsq, NULL, &rntm)
  532. elif reals_ft is double1d_t:
  533. bli_dsumsqv_ex(m, &x[0], incx, &scale[0], &sumsq[0], NULL, &rntm)
  534. else:
  535. with gil:
  536. raise ValueError("Unhandled fused type")
  537. @atexit.register
  538. def finalize():
  539. bli_finalize()