You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

107 lines
2.6 KiB

# Copyright ExplosionAI GmbH, released under BSD.
import numpy
import numpy.random
from .py import gemm, einsum
from timeit import default_timer as timer
# Seed numpy's global RNG at import time: create_data() draws from
# numpy.random.uniform, so this keeps the benchmark inputs (and therefore the
# printed "Total:" checksums) identical from run to run.
numpy.random.seed(0)
def create_data(nO, nI, batch_size):
    """Build a random ``(X, W)`` pair of float32 arrays for the benchmarks.

    ``X`` has shape ``(batch_size, nI)`` and ``W`` has shape ``(nO, nI)``.
    Entries are drawn from ``numpy.random.uniform(-1.0, 1.0, ...)`` using the
    global RNG (X first, then W), so the values depend on the current seed.
    """
    shapes = ((batch_size, nI), (nO, nI))
    inputs, weights = (
        numpy.random.uniform(-1.0, 1.0, shape).astype("f") for shape in shapes
    )
    return inputs, weights
def get_numpy_blas():
    """Return the name of the BLAS library numpy reports it was built with.

    Older (distutils-built) numpy exposes
    ``numpy.__config__.blas_opt_info["libraries"]``.  Meson-built numpy
    (1.26 and later) no longer has that attribute and instead publishes build
    details in ``numpy.__config__.CONFIG``.  Both layouts are tried in turn.
    ``"unknown"`` is returned when neither yields a name, so the benchmark
    can still run and print its timings.
    """
    config = getattr(numpy, "__config__", None)
    try:
        # Legacy layout: a list of library names; the first one is reported.
        return config.blas_opt_info["libraries"][0]
    except (AttributeError, KeyError, IndexError, TypeError):
        pass  # Not available on this numpy build; try the newer layout.
    try:
        return config.CONFIG["Build Dependencies"]["blas"]["name"]
    except (AttributeError, KeyError, TypeError):
        return "unknown"
def numpy_gemm(X, W, n=1000):
    """Run ``numpy.dot(X, W)`` *n* times into one float32 buffer; print the total.

    The product ``X @ W`` requires ``X.shape[1] == W.shape[0]``.  The output
    buffer is sized ``(X.shape[0], W.shape[1])`` to match that product, so
    non-square ``W`` works as well as square.  ``numpy.dot`` with ``out=``
    requires the buffer dtype to match the result dtype, so float32 inputs
    are expected.

    Returns the sum of every iteration's product, which is also printed.
    The sum keeps the work observable and gives a value to compare with
    the Blis run.
    """
    batch_size = X.shape[0]
    n_cols = W.shape[1]
    total = 0.0
    y = numpy.zeros((batch_size, n_cols), dtype="f")
    for _ in range(n):
        numpy.dot(X, W, out=y)
        total += y.sum()
        # Reset the buffer each iteration, as blis_gemm does, so both loops
        # perform the same operations.
        y.fill(0)
    print("Total:", total)
    return total
def blis_gemm(X, W, n=1000):
    """Run Blis ``gemm(X, W, out=y)`` *n* times into one float32 buffer; print the total.

    NOTE(review): this assumes Blis ``gemm(A, B, out=...)`` with default
    arguments computes ``A @ B`` (no transposes), matching numpy_gemm.  The
    output buffer is therefore sized ``(X.shape[0], W.shape[1])``, which also
    lets non-square ``W`` work.  Confirm against the blis ``gemm`` signature.

    Returns the sum of every iteration's product, which is also printed.
    """
    batch_size = X.shape[0]
    n_cols = W.shape[1]
    total = 0.0
    y = numpy.zeros((batch_size, n_cols), dtype="f")
    for _ in range(n):
        gemm(X, W, out=y)
        total += y.sum()
        # Zero the buffer before the next call. Presumably needed if gemm
        # accumulates into ``out`` (beta != 0). TODO confirm.
        y.fill(0.0)
    print("Total:", total)
    return total
def numpy_einsum(X, W, n=1000):
    """Run ``numpy.einsum("ab,cb->ca", X, W)`` *n* times and print a checksum.

    ``X`` is indexed ``ab`` and ``W`` is indexed ``cb``, so the result has
    shape ``(W.shape[0], X.shape[0])``.  Each iteration writes into the same
    preallocated float32 buffer, adds that buffer's sum to a running total,
    and zeroes the buffer.  The total is printed at the end.
    """
    n_rows, _ = W.shape
    out = numpy.zeros((n_rows, X.shape[0]), dtype="f")
    total = 0.0
    for _step in range(n):
        numpy.einsum("ab,cb->ca", X, W, out=out)
        total += out.sum()
        out.fill(0.0)
    print("Total:", total)
def blis_einsum(X, W, n=1000):
    """Run Blis ``einsum("ab,cb->ca", X, W, out=...)`` *n* times and print a checksum.

    The float32 output buffer has shape ``(W.shape[0], X.shape[0])``.  It is
    reused on every iteration: its sum is added to a running total and then
    it is zeroed.  The total is printed at the end.
    """
    n_rows, _ = W.shape
    out = numpy.zeros((n_rows, X.shape[0]), dtype="f")
    total = 0.0
    for _step in range(n):
        einsum("ab,cb->ca", X, W, out=out)
        total += out.sum()
        out.fill(0.0)
    print("Total:", total)
def _timed_run(label, func, *args, **kwargs):
    """Print *label*, call ``func(*args, **kwargs)`` once, then print and
    return the elapsed wall-clock seconds (measured with ``timer``)."""
    print(label)
    start = timer()
    func(*args, **kwargs)
    elapsed = timer() - start
    print("%.2f seconds" % elapsed)
    return elapsed


def main(nI=128 * 3, nO=128 * 3, batch_size=2000, n=1000):
    """Benchmark Blis against numpy for gemm and einsum, printing timings.

    Args:
        nI: Width shared by ``X`` (columns) and ``W`` (columns).
        nO: Number of rows of ``W``.
        batch_size: Number of rows of ``X``.
        n: Iterations for each of the four benchmarks.

    The gemm benchmarks compute ``X @ W`` with ``X`` of shape
    ``(batch_size, nI)`` and ``W`` of shape ``(nO, nI)``.  That product only
    lines up when ``nI == nO``; the einsum benchmarks accept any sizes.
    """
    print(
        "Setting up data for gemm. {n} iters, "
        "nO={nO} nI={nI} batch_size={batch_size}".format(
            n=n, nO=nO, nI=nI, batch_size=batch_size
        )
    )
    numpy_blas = get_numpy_blas()
    # Pass arguments in create_data's (nO, nI, batch_size) order.
    X1, W1 = create_data(nO, nI, batch_size)
    # Independent copies for the Blis runs; numpy always uses X1/W1.
    X2 = X1.copy()
    W2 = W1.copy()
    _timed_run("Blis gemm...", blis_gemm, X2, W2, n=n)
    _timed_run("Numpy (%s) gemm..." % numpy_blas, numpy_gemm, X1, W1, n=n)
    _timed_run("Blis einsum ab,cb->ca", blis_einsum, X2, W2, n=n)
    _timed_run(
        "Numpy (%s) einsum ab,cb->ca" % numpy_blas, numpy_einsum, X1, W1, n=n
    )
# Run the benchmark only when executed as a script. The previous
# ``if __name__:`` test was always true, so it ran on every import.
if __name__ == "__main__":
    main()