You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

530 lines
16 KiB

4 years ago
  1. # cython: infer_types=True
  2. # cython: cdivision=True
  3. cimport cython
  4. from libc.stdint cimport int32_t
  5. from libc.string cimport memset, memcpy
  6. from cymem.cymem cimport Pool
  7. from .typedefs cimport weight_t
  8. include "compile_time_constants.pxi"
  9. IF USE_BLAS:
  10. from blis cimport cy as blis
  11. cdef extern from "math.h" nogil:
  12. weight_t exp(weight_t x)
  13. weight_t sqrt(weight_t x)
  cdef class Matrix:
      # Declares a matrix of weight_t values together with its dimensions.
      # Field order is significant (it fixes the extension type's layout),
      # so the declarations below keep the original order.
      cdef readonly Pool mem      # cymem pool, readable from Python
      # Raw element pointer; C-level only (no readonly/public).
      # NOTE(review): presumably allocated from ``mem`` -- confirm in the .pyx.
      cdef weight_t* data
      # Row and column counts, both exposed read-only to Python.
      cdef readonly int32_t nr_row, nr_col
  cdef class Vec:
      # Operations on one contiguous vector of ``nr`` weight_t values.
      # Methods ending in ``_i`` modify their first argument in place.  The
      # matching non-``_i`` methods memcpy the input into ``output`` and then
      # apply the ``_i`` method there, so ``output`` needs room for ``nr``
      # values and the input itself is left unchanged.
      @staticmethod
      cdef inline int arg_max(const weight_t* scores, const int n_classes) nogil:
          # Index of the largest of scores[0:n_classes].  Ties resolve to the
          # lowest index for every n_classes (the same rule used by
          # VecVec.arg_max_if_true / arg_max_if_zero).  Returns 0 when
          # n_classes <= 1, and does not read ``scores`` when it is empty.
          if n_classes <= 1:
              return 0
          if n_classes == 2:
              # Strict ``>`` on the second score mirrors the general loop
              # below, so a tie returns 0 rather than 1.
              return 1 if scores[1] > scores[0] else 0
          cdef int i
          cdef int best = 0
          cdef weight_t mode = scores[0]
          for i in range(1, n_classes):
              if scores[i] > mode:
                  mode = scores[i]
                  best = i
          return best
      @staticmethod
      cdef inline weight_t max(const weight_t* x, int32_t nr) nogil:
          # Largest value in x[0:nr]; 0 for an empty (or negative-length)
          # vector, in which case ``x`` is not read.
          if nr <= 0:
              return 0
          cdef int i
          cdef weight_t mode = x[0]
          for i in range(1, nr):
              if x[i] > mode:
                  mode = x[i]
          return mode
      @staticmethod
      cdef inline weight_t sum(const weight_t* vec, int32_t nr) nogil:
          # Sum of vec[0:nr]; 0 when nr <= 0.
          cdef int i
          cdef weight_t total = 0
          for i in range(nr):
              total += vec[i]
          return total
      @staticmethod
      cdef inline weight_t norm(const weight_t* vec, int32_t nr) nogil:
          # Euclidean (L2) norm: sqrt of the sum of squares of vec[0:nr].
          cdef int i
          cdef weight_t total = 0
          for i in range(nr):
              total += vec[i] ** 2
          return sqrt(total)
      @staticmethod
      cdef inline void add(weight_t* output, const weight_t* x,
              weight_t inc, int32_t nr) nogil:
          # output[k] = x[k] + inc for k in [0, nr).
          memcpy(output, x, sizeof(output[0]) * nr)
          Vec.add_i(output, inc, nr)
      @staticmethod
      cdef inline void add_i(weight_t* vec, weight_t inc, int32_t nr) nogil:
          # vec[k] += inc for k in [0, nr).
          cdef int i
          for i in range(nr):
              vec[i] += inc
      @staticmethod
      cdef inline void mul(weight_t* output, const weight_t* vec, weight_t scal,
              int32_t nr) nogil:
          # output[k] = vec[k] * scal for k in [0, nr).
          memcpy(output, vec, sizeof(output[0]) * nr)
          Vec.mul_i(output, scal, nr)
      @staticmethod
      cdef inline void mul_i(weight_t* vec, weight_t scal, int32_t nr) nogil:
          # vec[k] *= scal for k in [0, nr); uses BLIS scalv when USE_BLAS.
          cdef int i
          IF USE_BLAS:
              blis.scalv(blis.NO_CONJUGATE, nr, scal, vec, 1)
          ELSE:
              for i in range(nr):
                  vec[i] *= scal
      @staticmethod
      cdef inline void pow(weight_t* output, const weight_t* vec, weight_t scal,
              int32_t nr) nogil:
          # output[k] = vec[k] ** scal for k in [0, nr).
          memcpy(output, vec, sizeof(output[0]) * nr)
          Vec.pow_i(output, scal, nr)
      @staticmethod
      cdef inline void pow_i(weight_t* vec, const weight_t scal, int32_t nr) nogil:
          # vec[k] **= scal for k in [0, nr).
          cdef int i
          for i in range(nr):
              vec[i] **= scal
      @staticmethod
      @cython.cdivision(True)
      cdef inline void div(weight_t* output, const weight_t* vec, weight_t scal,
              int32_t nr) nogil:
          # output[k] = vec[k] / scal for k in [0, nr); C division semantics.
          memcpy(output, vec, sizeof(output[0]) * nr)
          Vec.div_i(output, scal, nr)
      @staticmethod
      @cython.cdivision(True)
      cdef inline void div_i(weight_t* vec, const weight_t scal, int32_t nr) nogil:
          # vec[k] /= scal for k in [0, nr); no zero check (cdivision).
          cdef int i
          for i in range(nr):
              vec[i] /= scal
      @staticmethod
      cdef inline void exp(weight_t* output, const weight_t* vec, int32_t nr) nogil:
          # output[k] = exp(vec[k]) for k in [0, nr).
          memcpy(output, vec, sizeof(output[0]) * nr)
          Vec.exp_i(output, nr)
      @staticmethod
      cdef inline void exp_i(weight_t* vec, int32_t nr) nogil:
          # vec[k] = exp(vec[k]); ``exp`` here is the math.h extern, not the
          # static method of the same name above.
          cdef int i
          for i in range(nr):
              vec[i] = exp(vec[i])
      @staticmethod
      cdef inline void reciprocal_i(weight_t* vec, int32_t nr) nogil:
          # vec[k] = 1 / vec[k] for k in [0, nr).
          cdef int i
          for i in range(nr):
              vec[i] = 1.0 / vec[i]
  cdef class VecVec:
      # Element-wise operations that combine two vectors of length ``nr``.
      # ``*_i`` methods update their first argument in place; the others copy
      # ``x`` into ``output`` first and then run the in-place variant there.
      @staticmethod
      cdef inline void add(weight_t* output,
                           const weight_t* x,
                           const weight_t* y,
                           weight_t scale,
                           int32_t nr) nogil:
          # output = x + scale * y
          cdef size_t n_bytes = sizeof(output[0]) * nr
          memcpy(output, x, n_bytes)
          VecVec.add_i(output, y, scale, nr)
      @staticmethod
      cdef inline void add_i(weight_t* x,
                             const weight_t* y,
                             weight_t scale,
                             int32_t nr) nogil:
          # x += scale * y, via BLIS axpyv when USE_BLAS is set.
          cdef int k
          IF USE_BLAS:
              blis.axpyv(blis.NO_CONJUGATE, nr, scale, <weight_t*>y, 1, x, 1)
          ELSE:
              for k in range(nr):
                  x[k] += y[k] * scale
      @staticmethod
      cdef inline void batch_add_i(weight_t* x,
                                   const weight_t* y,
                                   weight_t scale,
                                   int32_t nr, int32_t nr_batch) nogil:
          # Accumulate nr_batch consecutive length-nr rows of ``y`` (each
          # scaled by ``scale``) into the single vector ``x``.
          cdef int b
          cdef const weight_t* y_row = y
          for b in range(nr_batch):
              VecVec.add_i(x, y_row, scale, nr)
              y_row += nr
      @staticmethod
      cdef inline void add_pow(weight_t* output,
              const weight_t* x, const weight_t* y, weight_t power, int32_t nr) nogil:
          # output = x + y ** power, element-wise.
          cdef size_t n_bytes = sizeof(output[0]) * nr
          memcpy(output, x, n_bytes)
          VecVec.add_pow_i(output, y, power, nr)
      @staticmethod
      cdef inline void add_pow_i(weight_t* x,
              const weight_t* y, weight_t power, int32_t nr) nogil:
          # x += y ** power, element-wise.
          cdef int k
          for k in range(nr):
              x[k] += y[k] ** power
      @staticmethod
      cdef inline void mul(weight_t* output,
              const weight_t* x, const weight_t* y, int32_t nr) nogil:
          # output = x * y, element-wise.
          cdef size_t n_bytes = sizeof(output[0]) * nr
          memcpy(output, x, n_bytes)
          VecVec.mul_i(output, y, nr)
      @staticmethod
      cdef inline void mul_i(weight_t* x,
              const weight_t* y, int32_t nr) nogil:
          # x *= y, element-wise.
          cdef int k
          for k in range(nr):
              x[k] *= y[k]
      @staticmethod
      cdef inline weight_t dot(
              const weight_t* x, const weight_t* y, int32_t nr) nogil:
          # Inner product: sum of x[k] * y[k] for k in [0, nr).
          cdef int k
          cdef weight_t acc = 0
          for k in range(nr):
              acc += x[k] * y[k]
          return acc
      @staticmethod
      cdef inline int arg_max_if_true(
              const weight_t* scores, const int* is_valid, const int n_classes) nogil:
          # Index of the best score among classes with a non-zero is_valid
          # flag; the lowest index wins ties; -1 when no class is valid.
          cdef int k
          cdef int best = -1
          for k in range(n_classes):
              if not is_valid[k]:
                  continue
              if best == -1 or scores[k] > scores[best]:
                  best = k
          return best
      @staticmethod
      cdef inline int arg_max_if_zero(
              const weight_t* scores, const weight_t* costs, const int n_classes) nogil:
          # Index of the best score among classes whose cost equals 0; the
          # lowest index wins ties; -1 when no class has zero cost.
          cdef int k
          cdef int best = -1
          for k in range(n_classes):
              if costs[k] != 0:
                  continue
              if best == -1 or scores[k] > scores[best]:
                  best = k
          return best
  cdef class Mat:
      # Per-column statistics over a row-major (nr_row, nr_col) matrix:
      # row r occupies mat[r * nr_col : (r + 1) * nr_col].
      @staticmethod
      cdef inline void mean_row(weight_t* Ex,
              const weight_t* mat, int32_t nr_row, int32_t nr_col) nogil:
          # Write the mean of each column (averaged over the nr_row rows)
          # into Ex[0:nr_col].
          # Empty input: with nr_col <= 0 nothing is written; with
          # nr_row <= 0, Ex is left zero-filled instead of being scaled by
          # 1.0 / 0 (which produced inf * 0 = NaN).
          cdef int i
          if nr_col <= 0:
              return
          memset(Ex, 0, sizeof(Ex[0]) * nr_col)
          if nr_row <= 0:
              return
          for i in range(nr_row):
              VecVec.add_i(Ex, &mat[i * nr_col], 1.0, nr_col)
          Vec.mul_i(Ex, 1.0 / nr_row, nr_col)
      @staticmethod
      cdef inline void var_row(weight_t* Vx,
              const weight_t* mat, const weight_t* Ex,
              int32_t nr_row, int32_t nr_col, weight_t eps) nogil:
          # Write each column's variance (divided by nr_row) plus ``eps``
          # into Vx[0:nr_col], given the column means in Ex.
          # Two-pass "compensated" formula, from
          # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
          # Vx is left untouched when the matrix is empty.
          if nr_row == 0 or nr_col == 0:
              return
          cdef int i, j
          cdef weight_t x, sum_, sum2
          for i in range(nr_col):
              sum_ = 0.0
              sum2 = 0.0
              for j in range(nr_row):
                  x = mat[j * nr_col + i]
                  sum2 += (x - Ex[i]) ** 2
                  # Residual sum corrects for rounding error in Ex[i].
                  sum_ += x - Ex[i]
              Vx[i] = (sum2 - sum_**2 / nr_row) / nr_row
              Vx[i] += eps
  cdef class MatVec:
      # Operations between a row-major (nr_row, nr_col) matrix and vectors.
      # The dot / T_dot family ACCUMULATES into ``output`` (BLAS beta = 1.0,
      # ``+=`` in the fallback).  NOTE(review): callers presumably zero
      # ``output`` first when a plain product is wanted -- confirm.
      @staticmethod
      cdef inline void add_i(weight_t* mat,
              const weight_t* vec, weight_t scale, int32_t nr_row, int32_t nr_col) nogil:
          # Add ``scale * vec`` (length nr_col) to every row of ``mat``.
          cdef int r
          for r in range(nr_row):
              VecVec.add_i(&mat[r * nr_col], vec, scale, nr_col)
      @staticmethod
      cdef inline void mul(weight_t* output,
              const weight_t* mat,
              const weight_t* vec,
              int32_t nr_row, int32_t nr_col) nogil:
          # output = mat with every row multiplied element-wise by ``vec``.
          cdef size_t n_bytes = sizeof(output[0]) * nr_row * nr_col
          memcpy(output, mat, n_bytes)
          MatVec.mul_i(output, vec, nr_row, nr_col)
      @staticmethod
      cdef inline void mul_i(weight_t* mat,
              const weight_t* vec,
              int32_t nr_row, int32_t nr_col) nogil:
          # Multiply every row of ``mat`` element-wise by ``vec``, in place.
          cdef int r, c
          cdef weight_t* row_ptr
          for r in range(nr_row):
              row_ptr = &mat[r * nr_col]
              for c in range(nr_col):
                  row_ptr[c] *= vec[c]
      @staticmethod
      cdef inline void dot(weight_t* output,
              const weight_t* mat,
              const weight_t* vec,
              int32_t nr_row, int32_t nr_col) nogil:
          # output[0:nr_row] += mat @ vec   (vec has nr_col entries)
          cdef int r, c
          IF USE_BLAS:
              blis.gemv(
                  blis.NO_TRANSPOSE,
                  blis.NO_CONJUGATE,
                  nr_row,
                  nr_col,
                  1.0,                              # alpha
                  <weight_t*>mat, nr_col, 1,        # A, row/col stride
                  <weight_t*>vec, 1,                # x, inc
                  1.0,                              # beta: accumulate
                  output, 1                         # y, inc
              )
          ELSE:
              for r in range(nr_row):
                  for c in range(nr_col):
                      output[r] += mat[r * nr_col + c] * vec[c]
      @staticmethod
      cdef inline void batch_dot(weight_t* output,
              const weight_t* mat,
              const weight_t* vec,
              int32_t nr_row, int32_t nr_col, int32_t nr_batch) nogil:
          # For each of nr_batch rows of ``vec`` (nr_batch x nr_col), add
          # mat @ row into the matching row of ``output`` (nr_batch x nr_row).
          # As a GEMM: out (M x N) += vec (M x K) @ mat.T (K x N), with
          # M = nr_batch, N = nr_row, K = nr_col.
          cdef int b
          cdef weight_t* out_row = output
          cdef const weight_t* vec_row = vec
          IF USE_BLAS:
              blis.gemm(
                  blis.NO_TRANSPOSE,
                  blis.TRANSPOSE,
                  nr_batch,                         # M
                  nr_row,                           # N
                  nr_col,                           # K
                  1.0,                              # alpha
                  <weight_t*>vec, nr_col, 1,        # A (M x K)
                  <weight_t*>mat, nr_col, 1,        # B stored (N x K), transposed
                  1.0,                              # beta: accumulate
                  output, nr_row, 1)                # C (M x N)
          ELSE:
              for b in range(nr_batch):
                  MatVec.dot(out_row, mat, vec_row, nr_row, nr_col)
                  out_row += nr_row
                  vec_row += nr_col
      @staticmethod
      cdef inline void T_dot(weight_t* output,
              const weight_t* mat,
              const weight_t* vec,
              int32_t nr_row,
              int32_t nr_col) nogil:
          # output[0:nr_col] += mat.T @ vec   (vec has nr_row entries)
          cdef int r, c
          IF USE_BLAS:
              blis.gemv(
                  blis.TRANSPOSE,
                  blis.NO_CONJUGATE,
                  nr_row, nr_col,
                  1.0,                              # alpha
                  <weight_t*>mat, nr_col, 1,        # A, row/col stride
                  <weight_t*>vec, 1,                # x, inc
                  1.0,                              # beta: accumulate
                  output, 1                         # y, inc
              )
          ELSE:
              for r in range(nr_row):
                  for c in range(nr_col):
                      output[c] += vec[r] * mat[(r * nr_col) + c]
      @staticmethod
      cdef inline void batch_T_dot(weight_t* output,
              const weight_t* mat,
              const weight_t* vec,
              int32_t nr_row,
              int32_t nr_col,
              int32_t nr_batch) nogil:
          # For each of nr_batch rows of ``vec`` (nr_batch x nr_row), add
          # mat.T @ row into the matching row of ``output`` (nr_batch x nr_col).
          # As a GEMM: out (M x N) += vec (M x K) @ mat (K x N), with
          # M = nr_batch, N = nr_col, K = nr_row.
          cdef int b
          cdef weight_t* out_row = output
          cdef const weight_t* vec_row = vec
          IF USE_BLAS:
              blis.gemm(
                  blis.NO_TRANSPOSE,
                  blis.NO_TRANSPOSE,
                  nr_batch,                         # M
                  nr_col,                           # N
                  nr_row,                           # K
                  1.0,                              # alpha
                  <weight_t*>vec, nr_row, 1,        # A (M x K)
                  <weight_t*>mat, nr_col, 1,        # B (K x N)
                  1.0,                              # beta: accumulate
                  output, nr_col, 1)                # C (M x N)
          ELSE:
              for b in range(nr_batch):
                  MatVec.T_dot(out_row, mat, vec_row, nr_row, nr_col)
                  out_row += nr_col
                  vec_row += nr_row
  cdef class MatMat:
      # Element-wise operations on two row-major (nr_row, nr_col) matrices,
      # plus outer-product accumulation into such a matrix.
      @staticmethod
      cdef inline void add(weight_t* output,
              const weight_t* x,
              const weight_t* y,
              int32_t nr_row, int32_t nr_col) nogil:
          # output = x + y, element-wise.
          cdef size_t n_bytes = sizeof(output[0]) * nr_row * nr_col
          memcpy(output, x, n_bytes)
          MatMat.add_i(output, y, nr_row, nr_col)
      @staticmethod
      cdef inline void add_i(weight_t* x,
              const weight_t* y,
              int32_t nr_row, int32_t nr_col) nogil:
          # x += y, element-wise, one row at a time.
          cdef int r, c
          cdef weight_t* x_row
          cdef const weight_t* y_row
          for r in range(nr_row):
              x_row = &x[r * nr_col]
              y_row = &y[r * nr_col]
              for c in range(nr_col):
                  x_row[c] += y_row[c]
      @staticmethod
      cdef inline void mul(weight_t* output,
              const weight_t* x,
              const weight_t* y,
              int32_t nr_row, int32_t nr_col) nogil:
          # output = x * y, element-wise.
          cdef size_t n_bytes = sizeof(output[0]) * nr_row * nr_col
          memcpy(output, x, n_bytes)
          MatMat.mul_i(output, y, nr_row, nr_col)
      @staticmethod
      cdef inline void mul_i(weight_t* x,
              const weight_t* y,
              int32_t nr_row, int32_t nr_col) nogil:
          # x *= y, element-wise, one row at a time.
          cdef int r, c
          cdef weight_t* x_row
          cdef const weight_t* y_row
          for r in range(nr_row):
              x_row = &x[r * nr_col]
              y_row = &y[r * nr_col]
              for c in range(nr_col):
                  x_row[c] *= y_row[c]
      @staticmethod
      cdef inline void add_outer_i(weight_t* mat,
              const weight_t* x,
              const weight_t* y,
              int32_t nr_row,
              int32_t nr_col) nogil:
          # mat[r, c] += x[r] * y[c]: rank-1 update (BLIS ger when USE_BLAS).
          cdef int r, c
          cdef weight_t* row_ptr
          IF USE_BLAS:
              blis.ger(
                  blis.NO_CONJUGATE, blis.NO_CONJUGATE,
                  nr_row, nr_col,
                  1.0,                              # alpha
                  <weight_t*>x, 1,                  # x (length nr_row)
                  <weight_t*>y, 1,                  # y (length nr_col)
                  mat, nr_col, 1                    # A, row/col stride
              )
          ELSE:
              for r in range(nr_row):
                  row_ptr = &mat[r * nr_col]
                  for c in range(nr_col):
                      row_ptr[c] += x[r] * y[c]
      @staticmethod
      cdef inline void batch_add_outer_i(weight_t* output,
              const weight_t* x,
              const weight_t* y,
              int32_t nr_row,
              int32_t nr_col,
              int32_t nr_batch) nogil:
          # Sum of outer products over a batch:
          #   output (nr_row x nr_col) += sum_b outer(x[b], y[b])
          # with x as (nr_batch x nr_row) and y as (nr_batch x nr_col).
          # As a GEMM: out (M x N) += x.T (M x K) @ y (K x N), with
          # M = nr_row, N = nr_col, K = nr_batch.
          cdef int b, r, c
          cdef weight_t* row_ptr
          cdef const weight_t* x_row = x
          cdef const weight_t* y_row = y
          IF USE_BLAS:
              blis.gemm(
                  blis.TRANSPOSE,
                  blis.NO_TRANSPOSE,
                  nr_row,                           # M
                  nr_col,                           # N
                  nr_batch,                         # K
                  1.0,                              # alpha
                  <weight_t*>x, nr_row, 1,          # A stored (K x M), transposed
                  <weight_t*>y, nr_col, 1,          # B (K x N)
                  1.0,                              # beta: accumulate
                  output, nr_col, 1)                # C (M x N)
          ELSE:
              for b in range(nr_batch):
                  for r in range(nr_row):
                      row_ptr = &output[r * nr_col]
                      for c in range(nr_col):
                          row_ptr[c] += x_row[r] * y_row[c]
                  x_row += nr_row
                  y_row += nr_col