You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

107 lines
2.6 KiB

4 years ago
  1. # Copyright ExplosionAI GmbH, released under BSD.
  2. import numpy
  3. import numpy.random
  4. from .py import gemm, einsum
  5. from timeit import default_timer as timer
# Seed numpy's global RNG so the random data drawn below is the same on
# every run of the benchmark.
numpy.random.seed(0)
  7. def create_data(nO, nI, batch_size):
  8. X = numpy.zeros((batch_size, nI), dtype="f")
  9. X += numpy.random.uniform(-1.0, 1.0, X.shape)
  10. W = numpy.zeros((nO, nI), dtype="f")
  11. W += numpy.random.uniform(-1.0, 1.0, W.shape)
  12. return X, W
  13. def get_numpy_blas():
  14. blas_libs = numpy.__config__.blas_opt_info["libraries"]
  15. return blas_libs[0]
  16. def numpy_gemm(X, W, n=1000):
  17. nO, nI = W.shape
  18. batch_size = X.shape[0]
  19. total = 0.0
  20. y = numpy.zeros((batch_size, nO), dtype="f")
  21. for i in range(n):
  22. numpy.dot(X, W, out=y)
  23. total += y.sum()
  24. y.fill(0)
  25. print("Total:", total)
  26. def blis_gemm(X, W, n=1000):
  27. nO, nI = W.shape
  28. batch_size = X.shape[0]
  29. total = 0.0
  30. y = numpy.zeros((batch_size, nO), dtype="f")
  31. for i in range(n):
  32. gemm(X, W, out=y)
  33. total += y.sum()
  34. y.fill(0.0)
  35. print("Total:", total)
  36. def numpy_einsum(X, W, n=1000):
  37. nO, nI = W.shape
  38. batch_size = X.shape[0]
  39. total = 0.0
  40. y = numpy.zeros((nO, batch_size), dtype="f")
  41. for i in range(n):
  42. numpy.einsum("ab,cb->ca", X, W, out=y)
  43. total += y.sum()
  44. y.fill(0.0)
  45. print("Total:", total)
  46. def blis_einsum(X, W, n=1000):
  47. nO, nI = W.shape
  48. batch_size = X.shape[0]
  49. total = 0.0
  50. y = numpy.zeros((nO, batch_size), dtype="f")
  51. for i in range(n):
  52. einsum("ab,cb->ca", X, W, out=y)
  53. total += y.sum()
  54. y.fill(0.0)
  55. print("Total:", total)
  56. def main(nI=128 * 3, nO=128 * 3, batch_size=2000):
  57. print(
  58. "Setting up data for gemm. 1000 iters, "
  59. "nO={nO} nI={nI} batch_size={batch_size}".format(**locals())
  60. )
  61. numpy_blas = get_numpy_blas()
  62. X1, W1 = create_data(nI, nO, batch_size)
  63. X2 = X1.copy()
  64. W2 = W1.copy()
  65. print("Blis gemm...")
  66. start = timer()
  67. blis_gemm(X2, W2, n=1000)
  68. end = timer()
  69. blis_time = end - start
  70. print("%.2f seconds" % blis_time)
  71. print("Numpy (%s) gemm..." % numpy_blas)
  72. start = timer()
  73. numpy_gemm(X1, W1)
  74. end = timer()
  75. numpy_time = end - start
  76. print("%.2f seconds" % numpy_time)
  77. print("Blis einsum ab,cb->ca")
  78. start = timer()
  79. blis_einsum(X2, W2, n=1000)
  80. end = timer()
  81. blis_time = end - start
  82. print("%.2f seconds" % blis_time)
  83. print("Numpy (%s) einsum ab,cb->ca" % numpy_blas)
  84. start = timer()
  85. numpy_einsum(X2, W2)
  86. end = timer()
  87. numpy_time = end - start
  88. print("%.2f seconds" % numpy_time)
  89. if __name__:
  90. main()