|
|
# Copyright ExplosionAI GmbH, released under BSD.
- import numpy
- import numpy.random
- from .py import gemm, einsum
- from timeit import default_timer as timer
-
# Seed NumPy's global RNG once at import so create_data draws the same
# matrices on every run of the benchmark.
numpy.random.seed(seed=0)
-
-
def create_data(nO, nI, batch_size):
    """Build random float32 benchmark operands.

    Entries are drawn uniformly from [-1.0, 1.0) using NumPy's global RNG;
    the input matrix is drawn before the weight matrix.

    Args:
        nO: Number of rows of the weight matrix.
        nI: Number of columns of both matrices.
        batch_size: Number of rows of the input matrix.

    Returns:
        (X, W) where X has shape (batch_size, nI) and W has shape (nO, nI),
        both float32.
    """
    inputs = numpy.random.uniform(-1.0, 1.0, (batch_size, nI)).astype("f")
    weights = numpy.random.uniform(-1.0, 1.0, (nO, nI)).astype("f")
    return inputs, weights
-
-
def get_numpy_blas():
    """Return a short name for the BLAS library NumPy was built against.

    Older (distutils-built) NumPy exposes ``numpy.__config__.blas_opt_info``;
    Meson-built NumPy (1.26+, 2.x) drops that attribute and instead provides
    a ``CONFIG`` dict. Both layouts are tried in turn.

    Returns:
        The first BLAS library name found, or ``"unknown"`` when neither
        layout reports one (e.g. a NumPy build without BLAS, where the
        legacy info dict is empty).
    """
    config = numpy.__config__

    # Legacy layout: {"libraries": ["openblas", ...], ...}; may be {}.
    legacy = getattr(config, "blas_opt_info", None)
    if legacy:
        libraries = legacy.get("libraries")
        if libraries:
            return libraries[0]

    # Modern layout: CONFIG["Build Dependencies"]["blas"]["name"].
    modern = getattr(config, "CONFIG", None)
    if isinstance(modern, dict):
        blas = modern.get("Build Dependencies", {}).get("blas", {})
        name = blas.get("name") if isinstance(blas, dict) else None
        if name:
            return str(name)

    return "unknown"
-
-
def numpy_gemm(X, W, n=1000):
    """Run ``numpy.dot(X, W)`` ``n`` times and print a checksum.

    The product is written into a preallocated float32 buffer of shape
    (X.shape[0], W.shape[0]). Because ``numpy.dot`` requires ``out`` to
    match the product's shape and dtype exactly, X and W must be float32
    and W must be square with side X.shape[1].

    Args:
        X: Left operand, shape (batch_size, nI).
        W: Right operand, shape (nO, nI); used without transposing.
        n: Number of iterations.

    Prints:
        ``Total: <sum over iterations of the buffer's sum>``.
    """
    n_out, _ = W.shape
    out = numpy.zeros((X.shape[0], n_out), dtype="f")
    checksum = 0.0
    for _ in range(n):
        numpy.dot(X, W, out=out)
        checksum += out.sum()
        # Zeroing is not needed for numpy.dot (it overwrites ``out``), but it
        # keeps the per-iteration work the same as the Blis variant.
        out.fill(0)
    print("Total:", checksum)
-
-
def blis_gemm(X, W, n=1000):
    """Run Blis ``gemm(X, W)`` ``n`` times and print a checksum.

    Each product is written into a preallocated float32 buffer of shape
    (X.shape[0], W.shape[0]). W is passed without transposing, so (as with
    numpy_gemm) the shapes only line up when W is square with side X.shape[1].

    Args:
        X: Left operand, shape (batch_size, nI).
        W: Right operand, shape (nO, nI).
        n: Number of iterations.

    Prints:
        ``Total: <sum over iterations of the buffer's sum>``.
    """
    n_out, _ = W.shape
    out = numpy.zeros((X.shape[0], n_out), dtype="f")
    checksum = 0.0
    for _ in range(n):
        gemm(X, W, out=out)
        checksum += out.sum()
        # NOTE(review): reset presumably matters if gemm accumulates into
        # ``out`` rather than overwriting it -- confirm against blis.py.
        out.fill(0.0)
    print("Total:", checksum)
-
-
def numpy_einsum(X, W, n=1000):
    """Run ``numpy.einsum("ab,cb->ca", X, W)`` ``n`` times and print a checksum.

    The contraction computes W @ X.T into a preallocated float32 buffer of
    shape (W.shape[0], X.shape[0]).

    Args:
        X: Shape (batch_size, nI).
        W: Shape (nO, nI).
        n: Number of iterations.

    Prints:
        ``Total: <sum over iterations of the buffer's sum>``.
    """
    n_out, _ = W.shape
    rows = X.shape[0]
    result = numpy.zeros((n_out, rows), dtype="f")
    grand_total = 0.0
    for _ in range(n):
        numpy.einsum("ab,cb->ca", X, W, out=result)
        grand_total += result.sum()
        # Reset so each iteration matches the Blis variant's work.
        result.fill(0.0)
    print("Total:", grand_total)
-
-
def blis_einsum(X, W, n=1000):
    """Run Blis ``einsum("ab,cb->ca", X, W)`` ``n`` times and print a checksum.

    Results go into a preallocated float32 buffer of shape
    (W.shape[0], X.shape[0]).

    Args:
        X: Shape (batch_size, nI).
        W: Shape (nO, nI).
        n: Number of iterations.

    Prints:
        ``Total: <sum over iterations of the buffer's sum>``.
    """
    n_out, _ = W.shape
    rows = X.shape[0]
    result = numpy.zeros((n_out, rows), dtype="f")
    grand_total = 0.0
    for _ in range(n):
        einsum("ab,cb->ca", X, W, out=result)
        grand_total += result.sum()
        # NOTE(review): reset presumably needed if Blis einsum accumulates
        # into ``out`` -- confirm against blis.py.
        result.fill(0.0)
    print("Total:", grand_total)
-
-
def main(nI=128 * 3, nO=128 * 3, batch_size=2000, n_iter=1000):
    """Benchmark Blis against NumPy on gemm and einsum ("ab,cb->ca").

    Each benchmark runs ``n_iter`` iterations and prints its wall-clock time.
    The Blis runs use their own copies (X2, W2) of the random data; the NumPy
    runs use the originals (X1, W1).

    Args:
        nI: Number of columns of X and W.
        nO: Number of rows of W.
        batch_size: Number of rows of X.
        n_iter: Iterations per benchmark (default 1000).

    Note:
        The gemm benchmarks multiply X by W without transposing W, so they
        only have matching shapes when ``nI == nO`` (the default).
    """
    print(
        "Setting up data for gemm. {n_iter} iters, "
        "nO={nO} nI={nI} batch_size={batch_size}".format(
            n_iter=n_iter, nO=nO, nI=nI, batch_size=batch_size
        )
    )
    numpy_blas = get_numpy_blas()
    # create_data's signature is (nO, nI, batch_size).
    X1, W1 = create_data(nO, nI, batch_size)
    X2 = X1.copy()
    W2 = W1.copy()

    def run_benchmark(label, func, X, W):
        # Print the label, time ``n_iter`` iterations, print elapsed seconds.
        print(label)
        start = timer()
        func(X, W, n=n_iter)
        end = timer()
        print("%.2f seconds" % (end - start))

    run_benchmark("Blis gemm...", blis_gemm, X2, W2)
    run_benchmark("Numpy (%s) gemm..." % numpy_blas, numpy_gemm, X1, W1)
    run_benchmark("Blis einsum ab,cb->ca", blis_einsum, X2, W2)
    run_benchmark(
        "Numpy (%s) einsum ab,cb->ca" % numpy_blas, numpy_einsum, X1, W1
    )
-
-
if __name__ == "__main__":
    # Run the benchmark only when executed as a script (e.g. via
    # ``python -m``), not whenever the module is imported; ``__name__`` alone
    # is always a non-empty string and therefore always truthy.
    main()
|