You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

186 lines
6.6 KiB

4 years ago
  1. # cython: boundscheck=False
  2. # Copyright ExplosionAI GmbH, released under BSD.
  3. cimport numpy as np
  4. from . cimport cy
  5. from .cy cimport reals1d_ft, reals2d_ft, float1d_t, float2d_t
  6. from .cy cimport const_reals1d_ft, const_reals2d_ft, const_float1d_t, const_float2d_t
  7. from .cy cimport const_double1d_t, const_double2d_t
  8. import numpy
  def axpy(const_reals1d_ft A, double scale=1., np.ndarray out=None):
      """Accumulate ``scale * A`` into ``out`` through ``cy.axpyv``.

      Parameters
      ----------
      A : 1-d float32 or float64 memoryview
          Input vector; read with unit stride starting at ``&A[0]``.
      scale : double
          Scalar forwarded to ``cy.axpyv`` as the multiplier of ``A``.
      out : numpy.ndarray, optional
          Destination vector. When omitted, a zero-filled array with
          ``A.shape[0]`` elements and the dtype of ``A`` is allocated.
          A caller-supplied array is used as-is: its data pointer is cast
          to ``float*`` / ``double*`` and written with unit stride, and
          neither its dtype nor its length is validated here.

      Returns
      -------
      numpy.ndarray
          ``out`` after the update (BLIS documents axpyv as
          ``y := y + alpha * x``).

      Raises
      ------
      TypeError
          If the fused type specialization is neither float nor double.
      """
      if const_reals1d_ft is const_float1d_t:
          if out is None:
              out = numpy.zeros((A.shape[0],), dtype='f')
          B = <float*>out.data
          # The float32 path must run the same kernel call as the float64
          # path; without it the function would return ``out`` untouched.
          with nogil:
              cy.axpyv(cy.NO_CONJUGATE, A.shape[0], scale, &A[0], 1, B, 1)
          return out
      elif const_reals1d_ft is const_double1d_t:
          if out is None:
              out = numpy.zeros((A.shape[0],), dtype='d')
          B = <double*>out.data
          with nogil:
              cy.axpyv(cy.NO_CONJUGATE, A.shape[0], scale, &A[0], 1, B, 1)
          return out
      else:
          raise TypeError("Unhandled fused type")
  def batch_axpy(reals2d_ft A, reals1d_ft B, np.ndarray out=None):
      """Placeholder for a batched axpy; currently not implemented.

      The arguments are accepted but ignored, no computation is performed
      and ``None`` is returned.
      """
      return None
  def ger(const_reals2d_ft A, const_reals1d_ft B, double scale=1., np.ndarray out=None):
      """Rank-1 update of ``out`` through ``cy.ger``.

      The first vector is read from ``&A[0, 0]`` with unit stride for
      ``A.shape[0]`` elements, the second from ``&B[0]`` with unit stride
      for ``B.shape[0]`` elements. ``out`` is addressed as a row-major
      matrix with row stride ``out.shape[1]`` and column stride 1.

      Parameters
      ----------
      A : 2-d float32 or float64 memoryview
          Source of the first vector (only its leading ``A.shape[0]``
          contiguous elements are passed to the kernel).
      B : 1-d memoryview with the same element type as ``A``
          Second vector.
      scale : double
          Scalar forwarded to ``cy.ger``.
      out : numpy.ndarray, optional
          Destination. When omitted, a zero-filled
          ``(A.shape[0], B.shape[0])`` array of the matching dtype is
          allocated. A supplied array is not validated.

      Returns
      -------
      numpy.ndarray
          ``out`` after the update.

      Raises
      ------
      TypeError
          If ``A`` and ``B`` are not both float32 or both float64.
      """
      cdef Py_ssize_t n_rows = A.shape[0]
      cdef Py_ssize_t n_cols = B.shape[0]
      if const_reals2d_ft is const_float2d_t and const_reals1d_ft is const_float1d_t:
          if out is None:
              out = numpy.zeros((n_rows, n_cols), dtype='f')
          with nogil:
              cy.ger(
                  cy.NO_CONJUGATE, cy.NO_CONJUGATE,
                  n_rows, n_cols,
                  scale,
                  &A[0, 0], 1,
                  &B[0], 1,
                  <float*>out.data, out.shape[1], 1
              )
          return out
      elif const_reals2d_ft is const_double2d_t and const_reals1d_ft is const_double1d_t:
          if out is None:
              out = numpy.zeros((n_rows, n_cols), dtype='d')
          with nogil:
              cy.ger(
                  cy.NO_CONJUGATE, cy.NO_CONJUGATE,
                  n_rows, n_cols,
                  scale,
                  &A[0, 0], 1,
                  &B[0], 1,
                  <double*>out.data, out.shape[1], 1
              )
          return out
      else:
          # Mixed float32/float64 specializations end up here.
          raise TypeError("Unhandled fused type")
  def gemm(const_reals2d_ft A, const_reals2d_ft B,
           np.ndarray out=None, bint trans1=False, bint trans2=False,
           double alpha=1., double beta=1.):
      """Matrix-matrix product ``op(A) @ op(B)`` through ``cy.gemm``.

      ``op(A)`` is ``A`` transposed when ``trans1`` is set (likewise
      ``op(B)`` / ``trans2``). Both inputs and ``out`` are addressed as
      row-major matrices (row stride ``shape[1]``, column stride 1).

      Parameters
      ----------
      A, B : 2-d float32 or float64 memoryviews of the same element type
          Operands. The shared dimension is taken from ``A`` only; ``B`` is
          not checked against it here.
      out : numpy.ndarray, optional
          Destination. When omitted, a zero-filled ``(nM, nN)`` array of the
          matching dtype is allocated, where ``nM``/``nN`` are the row and
          column counts of ``op(A)`` and ``op(B)``. A supplied array is not
          validated.
      trans1, trans2 : bool
          Transpose flags for ``A`` and ``B``.
      alpha, beta : double
          Scalars forwarded to ``cy.gemm`` (BLIS documents gemm as
          ``C := beta * C + alpha * op(A) * op(B)``, so with the default
          ``beta=1`` a supplied ``out`` is accumulated into).

      Returns
      -------
      numpy.ndarray
          ``out`` after the update.

      Raises
      ------
      TypeError
          If the fused type specialization is neither float nor double.
      """
      # Dimensions of the *operated* matrices: op(A) is nM x nK and
      # op(B) is nK x nN, so the result is nM x nN.
      cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
      cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
      cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
      if const_reals2d_ft is const_float2d_t:
          if out is None:
              out = numpy.zeros((nM, nN), dtype='f')
          C = <float*>out.data
          with nogil:
              cy.gemm(
                  cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                  cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                  nM, nN, nK,
                  alpha,
                  &A[0,0], A.shape[1], 1,
                  &B[0,0], B.shape[1], 1,
                  beta,
                  C, out.shape[1], 1)
          return out
      elif const_reals2d_ft is const_double2d_t:
          # Use the transpose-aware dimensions here as well; sizing from the
          # raw shapes of A and B is only correct when neither is transposed.
          if out is None:
              out = numpy.zeros((nM, nN), dtype='d')
          C = <double*>out.data
          with nogil:
              cy.gemm(
                  cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                  cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                  nM, nN, nK,
                  alpha,
                  &A[0,0], A.shape[1], 1,
                  &B[0,0], B.shape[1], 1,
                  beta,
                  C, out.shape[1], 1)
          return out
      else:
          raise TypeError("Unhandled fused type")
  def gemv(const_reals2d_ft A, const_reals1d_ft B,
           bint trans1=False, double alpha=1., double beta=1.,
           np.ndarray out=None):
      """Matrix-vector product ``op(A) @ B`` through ``cy.gemv``.

      ``op(A)`` is ``A`` transposed when ``trans1`` is set. ``A`` is
      addressed as a row-major matrix (row stride ``A.shape[1]``, column
      stride 1); ``B`` and ``out`` are accessed with unit stride.

      Parameters
      ----------
      A : 2-d float32 or float64 memoryview
          Matrix operand; its stored shape ``(A.shape[0], A.shape[1])`` is
          what is passed to the kernel.
      B : 1-d memoryview with the same element type as ``A``
          Vector operand. Its length is not checked here.
      trans1 : bool
          Whether to multiply by the transpose of ``A``.
      alpha, beta : double
          Scalars forwarded to ``cy.gemv`` (BLIS documents gemv as
          ``y := beta * y + alpha * op(A) * x``).
      out : numpy.ndarray, optional
          Destination vector. When omitted, a zero-filled array is
          allocated with ``A.shape[0]`` elements, or ``A.shape[1]`` elements
          when ``trans1`` is set. A supplied array is not validated.

      Returns
      -------
      numpy.ndarray
          ``out`` after the update.

      Raises
      ------
      TypeError
          If ``A`` and ``B`` are not both float32 or both float64.
      """
      # op(A) @ B has as many entries as op(A) has rows, which is the column
      # count of A when transposing. Allocating A.shape[0] unconditionally
      # would under-size the buffer whenever A.shape[1] > A.shape[0].
      cdef Py_ssize_t n_out = A.shape[1] if trans1 else A.shape[0]
      if const_reals1d_ft is const_float1d_t and const_reals2d_ft is const_float2d_t:
          if out is None:
              out = numpy.zeros((n_out,), dtype='f')
          with nogil:
              cy.gemv(
                  cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                  cy.NO_CONJUGATE,
                  A.shape[0], A.shape[1],
                  alpha,
                  &A[0,0], A.shape[1], 1,
                  &B[0], 1,
                  beta,
                  <float*>out.data, 1)
          return out
      elif const_reals1d_ft is const_double1d_t and const_reals2d_ft is const_double2d_t:
          if out is None:
              out = numpy.zeros((n_out,), dtype='d')
          with nogil:
              cy.gemv(
                  cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                  cy.NO_CONJUGATE,
                  A.shape[0], A.shape[1],
                  alpha,
                  &A[0,0], A.shape[1], 1,
                  &B[0], 1,
                  beta,
                  <double*>out.data, 1)
          return out
      else:
          raise TypeError("Unhandled fused type")
  def dotv(const_reals1d_ft X, const_reals1d_ft Y, bint conjX=False, bint conjY=False):
      """Dot product of two equal-length vectors through ``cy.dotv``.

      Parameters
      ----------
      X, Y : 1-d float32 or float64 memoryviews of the same element type
          Operands, both read with unit stride.
      conjX, conjY : bool
          Select ``cy.CONJUGATE`` instead of ``cy.NO_CONJUGATE`` for the
          corresponding operand.

      Returns
      -------
      The value returned by ``cy.dotv``.

      Raises
      ------
      ValueError
          If ``X`` and ``Y`` differ in length.
      """
      cdef Py_ssize_t len_x = X.shape[0]
      cdef Py_ssize_t len_y = Y.shape[0]
      if len_x != len_y:
          raise ValueError(
              "Shape mismatch for blis.dotv: (%d,), (%d,)" % (len_x, len_y)
          )
      return cy.dotv(
          cy.CONJUGATE if conjX else cy.NO_CONJUGATE,
          cy.CONJUGATE if conjY else cy.NO_CONJUGATE,
          len_x, &X[0], &Y[0], 1, 1
      )
  def einsum(todo, A, B, out=None):
      """Dispatch a two-operand einsum signature to a wrapper in this module.

      Parameters
      ----------
      todo : str
          Einsum signature, e.g. ``'ab,bc->ac'``. Only the signatures listed
          in the body are recognised.
      A, B :
          Operands, forwarded (possibly swapped) to the selected wrapper.
      out : optional
          Output array forwarded to the selected wrapper.

      Returns
      -------
      Whatever the selected wrapper returns.

      Raises
      ------
      ValueError
          If ``todo`` is not one of the supported signatures.
      """
      if todo == 'a,a->a':
          # NOTE(review): axpy's second positional parameter is the scalar
          # ``scale``, so passing the vector B here looks wrong -- confirm
          # the intended kernel for the element-wise product.
          return axpy(A, B, out=out)
      elif todo == 'a,b->ab':
          # NOTE(review): ger declares its first argument as a 2-d view while
          # this signature implies a 1-d operand -- confirm against callers.
          return ger(A, B, out=out)
      elif todo == 'a,b->ba':
          return ger(B, A, out=out)
      elif todo == 'ab,a->ab':
          return batch_axpy(A, B, out=out)
      elif todo == 'ab,a->ba':
          # NOTE(review): batch_axpy does not declare a ``trans1`` parameter,
          # so this call raises TypeError as the module stands.
          return batch_axpy(A, B, trans1=True, out=out)
      elif todo == 'ab,b->a':
          return gemv(A, B, out=out)
      elif todo == 'ab,a->b':
          return gemv(A, B, trans1=True, out=out)
      # The rule for the matrix-matrix cases: look at the first index of the
      # output. The operand containing it becomes gemm's first argument, and
      # it is transposed when that index is its second dimension. The second
      # argument is transposed when the contracted index is its second
      # dimension.
      # E.g. for 'ab,ac->bc' the output starts with b, which is the second
      # dimension of 'ab', so trans1=True turns it into ba,ac->bc.
      elif todo == 'ab,ac->bc':
          return gemm(A, B, trans1=True, trans2=False, out=out)
      elif todo == 'ab,ac->cb':
          # out[c, b] = sum_a B[a, c] * A[a, b] = (B^T @ A)[c, b]: only the
          # first gemm operand (B) is transposed; A is used as stored.
          return gemm(B, A, out=out, trans1=True, trans2=False)
      elif todo == 'ab,bc->ac':
          return gemm(A, B, out=out, trans1=False, trans2=False)
      elif todo == 'ab,bc->ca':
          return gemm(B, A, out=out, trans1=True, trans2=True)
      elif todo == 'ab,ca->bc':
          return gemm(A, B, out=out, trans1=True, trans2=True)
      elif todo == 'ab,ca->cb':
          return gemm(B, A, out=out, trans1=False, trans2=False)
      elif todo == 'ab,cb->ac':
          return gemm(A, B, out=out, trans1=False, trans2=True)
      elif todo == 'ab,cb->ca':
          return gemm(B, A, out=out, trans1=False, trans2=True)
      else:
          raise ValueError("Invalid einsum: %s" % todo)