You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

182 lines
2.9 KiB

4 years ago
  # Copyright ExplosionAI GmbH, released under BSD.
  from cython cimport view
  from libc.stdint cimport int64_t

  # Buffer aliases for single- and double-precision data.
  # `[::1]` is a C-contiguous 1-D typed memoryview and `[:, ::1]` a
  # C-contiguous 2-D typed memoryview; `floats_t` / `doubles_t` are raw
  # pointers.

  # Mutable variants.
  ctypedef float[::1] float1d_t
  ctypedef double[::1] double1d_t
  ctypedef float[:, ::1] float2d_t
  ctypedef double[:, ::1] double2d_t
  ctypedef float* floats_t
  ctypedef double* doubles_t

  # Read-only (const) counterparts of the aliases above.
  ctypedef const float[::1] const_float1d_t
  ctypedef const double[::1] const_double1d_t
  ctypedef const float[:, ::1] const_float2d_t
  ctypedef const double[:, ::1] const_double2d_t
  ctypedef const float* const_floats_t
  ctypedef const double* const_doubles_t
  # Fused type for a mutable real buffer: either a raw pointer or a
  # contiguous 1-D memoryview, in single or double precision.
  # Arguments declared with this type in one signature all resolve to the
  # same specialization.
  cdef fused reals_ft:
      floats_t
      doubles_t
      float1d_t
      double1d_t
  # Read-only counterpart of reals_ft: const raw pointer or const contiguous
  # 1-D memoryview, in single or double precision.
  cdef fused const_reals_ft:
      const_floats_t
      const_doubles_t
      const_float1d_t
      const_double1d_t
  # Fused type restricted to mutable contiguous 1-D memoryviews
  # (float or double); raw pointers are not accepted.
  cdef fused reals1d_ft:
      float1d_t
      double1d_t
  # Fused type restricted to read-only contiguous 1-D memoryviews
  # (const float or const double).
  cdef fused const_reals1d_ft:
      const_float1d_t
      const_double1d_t
  # Fused type restricted to mutable C-contiguous 2-D memoryviews
  # (float or double).
  cdef fused reals2d_ft:
      float2d_t
      double2d_t
  # Fused type restricted to read-only C-contiguous 2-D memoryviews
  # (const float or const double).
  cdef fused const_reals2d_ft:
      const_float2d_t
      const_double2d_t
  # Fused scalar type: a single float or double value (as opposed to the
  # buffer-like reals_ft).
  cdef fused real_ft:
      float
      double
  # 64-bit signed integer aliases used in the signatures below.
  # NOTE(review): names suggest dimension / increment (stride) / offset;
  # doff_t is not used by any declaration in this part of the file.
  ctypedef int64_t dim_t
  ctypedef int64_t inc_t
  ctypedef int64_t doff_t
  # Sucks to set these from magic numbers, but it's better than dragging
  # the header into our header.
  # We get some peace of mind from checking the values on init.
  # Transposition / conjugation selector for matrix operands.
  # Declared `cpdef`, so the members are usable from both Cython and Python.
  # The numeric values are hard-coded here rather than taken from the
  # library header; do not renumber them.
  cpdef enum trans_t:
      NO_TRANSPOSE = 0
      TRANSPOSE = 8
      CONJ_NO_TRANSPOSE = 16
      CONJ_TRANSPOSE = 24
  # Conjugation selector for vector / scalar operands.
  # CONJUGATE shares the value 16 with trans_t.CONJ_NO_TRANSPOSE.
  # Values are hard-coded rather than taken from the library header.
  cpdef enum conj_t:
      NO_CONJUGATE = 0
      CONJUGATE = 16
  # Left/right side selector. Not referenced by any declaration in this
  # part of the file.
  # Values are hard-coded rather than taken from the library header.
  cpdef enum side_t:
      LEFT = 0
      RIGHT = 1
  # Matrix storage selector (lower / upper / dense). Not referenced by any
  # declaration in this part of the file.
  # Values are hard-coded rather than taken from the library header; note
  # that they are not in ascending order.
  cpdef enum uplo_t:
      LOWER = 192
      UPPER = 96
      DENSE = 224
  # Diagonal selector (unit vs. non-unit). Not referenced by any
  # declaration in this part of the file.
  # Values are hard-coded rather than taken from the library header.
  cpdef enum diag_t:
      NONUNIT_DIAG = 0
      UNIT_DIAG = 256
  # gemm -- declaration only; the implementation is not in this file.
  # Declared `nogil`, so it may be called without holding the GIL.
  #
  #   transa, transb : transpose/conjugate flags for operands a and b
  #   m, n, k        : problem dimensions
  #                    NOTE(review): presumably c is m x n with inner
  #                    dimension k -- confirm against the implementation
  #   alpha, beta    : scalars; always `double`, whatever the element type
  #                    of the buffers
  #   a, b, c        : operand buffers; all three use reals_ft and therefore
  #                    share one specialization (same precision, same
  #                    pointer-vs-memoryview kind)
  #   rs*, cs*       : per-operand increments
  #                    NOTE(review): presumably row and column strides
  cdef void gemm(
      trans_t transa,
      trans_t transb,
      dim_t m,
      dim_t n,
      dim_t k,
      double alpha,
      reals_ft a, inc_t rsa, inc_t csa,
      reals_ft b, inc_t rsb, inc_t csb,
      double beta,
      reals_ft c, inc_t rsc, inc_t csc,
  ) nogil
  # ger -- declaration only; the implementation is not in this file.
  # Declared `nogil`, so it may be called without holding the GIL.
  #
  #   conjx, conjy : conjugation flags for x and y
  #   m, n         : problem dimensions
  #                  NOTE(review): presumably a is m x n -- confirm
  #   alpha        : scalar; always `double`, whatever the buffer precision
  #   x, y, a      : operand buffers sharing one reals_ft specialization
  #   incx, incy   : increments for x and y
  #   rsa, csa     : increments for a
  #                  NOTE(review): presumably row and column strides
  cdef void ger(
      conj_t conjx,
      conj_t conjy,
      dim_t m,
      dim_t n,
      double alpha,
      reals_ft x, inc_t incx,
      reals_ft y, inc_t incy,
      reals_ft a, inc_t rsa, inc_t csa
  ) nogil
  # gemv -- declaration only; the implementation is not in this file.
  # Declared `nogil`, so it may be called without holding the GIL.
  #
  #   transa      : transpose/conjugate flag for a
  #   conjx       : conjugation flag for x
  #   m, n        : problem dimensions
  #   alpha, beta : scalars of fused type real_ft (both share one
  #                 specialization, chosen independently of the buffers)
  #   a, x, y     : operand buffers sharing one reals_ft specialization
  #   rsa, csa    : increments for a
  #                 NOTE(review): presumably row and column strides
  #   incx, incy  : increments for x and y
  cdef void gemv(
      trans_t transa,
      conj_t conjx,
      dim_t m,
      dim_t n,
      real_ft alpha,
      reals_ft a, inc_t rsa, inc_t csa,
      reals_ft x, inc_t incx,
      real_ft beta,
      reals_ft y, inc_t incy
  ) nogil
  # axpyv -- declaration only; the implementation is not in this file.
  # Declared `nogil`, so it may be called without holding the GIL.
  #
  #   conjx      : conjugation flag for x
  #   m          : vector length
  #   alpha      : scalar of fused type real_ft
  #   x, y       : vector buffers sharing one reals_ft specialization
  #   incx, incy : increments for x and y
  # NOTE(review): by its name this presumably computes y := y + alpha * x;
  # confirm against the implementation.
  cdef void axpyv(
      conj_t conjx,
      dim_t m,
      real_ft alpha,
      reals_ft x, inc_t incx,
      reals_ft y, inc_t incy
  ) nogil
  # scalv -- declaration only; the implementation is not in this file.
  # Declared `nogil`, so it may be called without holding the GIL.
  #
  #   conjalpha : conjugation flag applied to alpha
  #   m         : vector length
  #   alpha     : scalar of fused type real_ft
  #   x, incx   : vector buffer (reals_ft) and its increment
  # NOTE(review): by its name this presumably scales x by alpha in place;
  # confirm against the implementation.
  cdef void scalv(
      conj_t conjalpha,
      dim_t m,
      real_ft alpha,
      reals_ft x, inc_t incx
  ) nogil
  # dotv -- declaration only; the implementation is not in this file.
  # Returns a `double` whatever the precision of the input buffers.
  # Declared `nogil`, so it may be called without holding the GIL.
  #
  #   conjx, conjy : conjugation flags for x and y
  #   m            : vector length
  #   x, y         : vector buffers sharing one reals_ft specialization
  #   incx, incy   : increments for x and y
  #
  # Caution: unlike the sibling routines, which interleave each buffer with
  # its increment (x, incx, y, incy), this signature takes both buffers
  # first and both increments afterwards. The order is kept as is because
  # callers depend on it.
  cdef double dotv(
      conj_t conjx,
      conj_t conjy,
      dim_t m,
      reals_ft x,
      reals_ft y,
      inc_t incx,
      inc_t incy,
  ) nogil
  # norm_L1 -- declaration only; the implementation is not in this file.
  # Takes a vector buffer x (reals_ft) with length n and increment incx and
  # returns a `double`. Declared `nogil`.
  # NOTE(review): presumably the L1 norm (sum of absolute values) -- confirm.
  cdef double norm_L1(
      dim_t n,
      reals_ft x, inc_t incx
  ) nogil
  # norm_L2 -- declaration only; the implementation is not in this file.
  # Takes a vector buffer x (reals_ft) with length n and increment incx and
  # returns a `double`. Declared `nogil`.
  # NOTE(review): presumably the Euclidean (L2) norm -- confirm.
  cdef double norm_L2(
      dim_t n,
      reals_ft x, inc_t incx
  ) nogil
  # norm_inf -- declaration only; the implementation is not in this file.
  # Takes a vector buffer x (reals_ft) with length n and increment incx and
  # returns a `double`. Declared `nogil`.
  # NOTE(review): presumably the infinity norm (largest absolute value) --
  # confirm.
  cdef double norm_inf(
      dim_t n,
      reals_ft x, inc_t incx
  ) nogil
  # randv -- declaration only; the implementation is not in this file.
  # Takes a vector buffer x (reals_ft) with length m and increment incx;
  # returns nothing. Declared `nogil`.
  # NOTE(review): presumably overwrites x with random values; the
  # distribution and seeding are not visible here -- confirm.
  cdef void randv(
      dim_t m,
      reals_ft x, inc_t incx
  ) nogil