You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

186 lines
6.6 KiB

# cython: boundscheck=False
# Copyright ExplosionAI GmbH, released under BSD.
cimport numpy as np
from . cimport cy
from .cy cimport reals1d_ft, reals2d_ft, float1d_t, float2d_t
from .cy cimport const_reals1d_ft, const_reals2d_ft, const_float1d_t, const_float2d_t
from .cy cimport const_double1d_t, const_double2d_t
import numpy
def axpy(const_reals1d_ft A, double scale=1., np.ndarray out=None):
    """Add ``scale * A`` into ``out`` with BLIS ``axpyv`` and return ``out``.

    Args:
        A: 1-D float32 or float64 array (fused memoryview type).
        scale: Multiplier applied to ``A`` (passed as BLIS ``alpha``).
        out: Optional 1-D output array. When None, a zero-filled array of
            length ``A.shape[0]`` with A's dtype ('f' or 'd') is allocated,
            so the result is ``scale * A``. When given, it is updated in
            place. NOTE(review): its dtype and length are not checked here;
            its buffer is reinterpreted as float*/double* to match ``A``.

    Returns:
        ``out``.

    Raises:
        TypeError: for an unhandled fused-type specialisation.
    """
    if const_reals1d_ft is const_float1d_t:
        if out is None:
            out = numpy.zeros((A.shape[0],), dtype='f')
        B = <float*>out.data
        # The float32 path must run the same kernel as the float64 path;
        # without this call ``out`` would be returned unmodified.
        with nogil:
            cy.axpyv(cy.NO_CONJUGATE, A.shape[0], scale, &A[0], 1, B, 1)
        return out
    elif const_reals1d_ft is const_double1d_t:
        if out is None:
            out = numpy.zeros((A.shape[0],), dtype='d')
        B = <double*>out.data
        with nogil:
            cy.axpyv(cy.NO_CONJUGATE, A.shape[0], scale, &A[0], 1, B, 1)
        return out
    else:
        # NOTE(review): B = NULL appears to exist only so B is assigned on
        # this path (possibly for Cython type inference) -- confirm before
        # removing.
        B = NULL
        raise TypeError("Unhandled fused type")
def batch_axpy(reals2d_ft A, reals1d_ft B, np.ndarray out=None, bint trans1=False):
    """Batched axpy -- not implemented.

    Placeholder kept for API compatibility. It raises instead of silently
    returning None, so callers such as ``einsum('ab,a->ab', ...)`` fail
    loudly rather than receiving None as a "result".

    Args:
        A: 2-D float array (fused memoryview type); currently unused.
        B: 1-D float array (fused memoryview type); currently unused.
        out: Optional output array; currently unused.
        trans1: Accepted because ``einsum`` passes it for ``'ab,a->ba'``;
            currently unused.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("batch_axpy is not implemented yet")
def ger(const_reals2d_ft A, const_reals1d_ft B, double scale=1., np.ndarray out=None):
    """Rank-1 update: add ``scale * x * B^T`` into ``out`` via BLIS ``ger``.

    The vector ``x`` is taken from ``A``: ``A.shape[0]`` elements read with
    unit stride starting at ``A[0, 0]``.
    NOTE(review): A is declared 2-D yet consumed as a vector, so only a
    contiguous (m, 1) array yields a meaningful result, and a 1-D array does
    not match the declared type -- confirm whether A was meant to be 1-D.

    Args:
        A: float32/float64 2-D array supplying the length-m vector.
        B: 1-D array of the same float type, length n.
        scale: Multiplier passed as BLIS ``alpha``.
        out: Optional (m, n) array. When None, a zero-filled array of A's
            dtype is allocated. Its second dimension is passed to ``cy.ger``
            as the row stride; its dtype/shape are not checked here.

    Returns:
        ``out``.

    Raises:
        TypeError: if A and B are not both float32 or both float64.
    """
    cdef cy.dim_t n_rows = A.shape[0]
    cdef cy.dim_t n_cols = B.shape[0]
    if const_reals2d_ft is const_double2d_t and const_reals1d_ft is const_double1d_t:
        if out is None:
            out = numpy.zeros((n_rows, n_cols), dtype='d')
        with nogil:
            cy.ger(cy.NO_CONJUGATE, cy.NO_CONJUGATE,
                   n_rows, n_cols, scale,
                   &A[0, 0], 1,
                   &B[0], 1,
                   <double*>out.data, out.shape[1], 1)
        return out
    elif const_reals2d_ft is const_float2d_t and const_reals1d_ft is const_float1d_t:
        if out is None:
            out = numpy.zeros((n_rows, n_cols), dtype='f')
        with nogil:
            cy.ger(cy.NO_CONJUGATE, cy.NO_CONJUGATE,
                   n_rows, n_cols, scale,
                   &A[0, 0], 1,
                   &B[0], 1,
                   <float*>out.data, out.shape[1], 1)
        return out
    else:
        # Mixed float32/float64 operands are not supported.
        raise TypeError("Unhandled fused type")
def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    """Compute ``out = beta * out + alpha * op(A) @ op(B)`` with BLIS ``gemm``.

    ``op(X)`` is ``X.T`` when the matching ``trans1``/``trans2`` flag is set,
    otherwise ``X``. A and B share one fused float type (float32 or float64).

    Args:
        A: 2-D operand, passed with row stride ``A.shape[1]`` and unit
            column stride.
        B: 2-D operand, passed with row stride ``B.shape[1]`` and unit
            column stride.
        out: Optional (M, N) result array. When None, a zero-filled array of
            A's dtype is allocated. When given, it is updated in place;
            NOTE(review): its dtype/shape are not checked here.
        trans1: Use A transposed.
        trans2: Use B transposed.
        alpha: Scale applied to the product.
        beta: Scale applied to the existing ``out``. With the default of 1
            the product is added into a caller-supplied ``out``.

    Returns:
        ``out``.

    Raises:
        ValueError: if the inner dimensions of op(A) and op(B) differ.
        TypeError: for an unhandled fused-type specialisation.
    """
    # Dimensions after applying the transpose flags:
    # op(A) is (nM, nK) and op(B) is (nK_B, nN).
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nK_B = B.shape[0] if not trans2 else B.shape[1]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    if nK != nK_B:
        # Mismatched inner dimensions would make BLIS read the wrong (or
        # out-of-bounds) elements of B, so reject them up front.
        msg = ("Shape mismatch for blis.gemm: (%d, %d), (%d, %d) "
               "with trans1=%s, trans2=%s")
        raise ValueError(msg % (A.shape[0], A.shape[1],
                                B.shape[0], B.shape[1], trans1, trans2))
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    elif const_reals2d_ft is const_double2d_t:
        if out is None:
            # Use the transposition-aware dims, exactly like the float path.
            out = numpy.zeros((nM, nN), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    else:
        # NOTE(review): C = NULL appears to exist only so C is assigned on
        # this path -- confirm before removing.
        C = NULL
        raise TypeError("Unhandled fused type")
def gemv(const_reals2d_ft A, const_reals1d_ft B,
         bint trans1=False, double alpha=1., double beta=1.,
         np.ndarray out=None):
    """Compute ``out = beta * out + alpha * op(A) @ B`` with BLIS ``gemv``.

    ``op(A)`` is ``A.T`` when ``trans1`` is set, otherwise ``A``. A and B
    must both be float32 or both float64.

    Args:
        A: 2-D matrix, passed with row stride ``A.shape[1]`` and unit column
            stride.
        B: 1-D vector whose length equals the column count of op(A).
        trans1: Use A transposed.
        alpha: Scale applied to the product.
        beta: Scale applied to the existing ``out``. With the default of 1
            the product is added into a caller-supplied ``out``.
        out: Optional 1-D result whose length is the row count of op(A).
            When None, a zero-filled array of A's dtype is allocated.
            NOTE(review): a supplied ``out`` is not dtype/shape-checked.

    Returns:
        ``out``.

    Raises:
        ValueError: if ``B``'s length does not match op(A).
        TypeError: for an unhandled fused-type combination.
    """
    # op(A) is (nM, nN): the dimensions swap when trans1 is set, matching
    # how gemm passes post-transpose dimensions to BLIS.
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nN = A.shape[1] if not trans1 else A.shape[0]
    if B.shape[0] != nN:
        msg = "Shape mismatch for blis.gemv: (%d, %d), (%d,) with trans1=%s"
        raise ValueError(msg % (A.shape[0], A.shape[1], B.shape[0], trans1))
    if const_reals1d_ft is const_float1d_t and const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM,), dtype='f')
        with nogil:
            cy.gemv(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.NO_CONJUGATE,
                nM, nN,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0], 1,
                beta,
                <float*>out.data, 1)
        return out
    elif const_reals1d_ft is const_double1d_t and const_reals2d_ft is const_double2d_t:
        if out is None:
            out = numpy.zeros((nM,), dtype='d')
        with nogil:
            cy.gemv(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.NO_CONJUGATE,
                nM, nN,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0], 1,
                beta,
                <double*>out.data, 1)
        return out
    else:
        raise TypeError("Unhandled fused type")
def dotv(const_reals1d_ft X, const_reals1d_ft Y, bint conjX=False, bint conjY=False):
    """Return the dot product of two equal-length 1-D arrays via ``cy.dotv``.

    Args:
        X: 1-D float32 or float64 array.
        Y: 1-D array of the same float type as ``X``.
        conjX: Pass ``cy.CONJUGATE`` instead of ``cy.NO_CONJUGATE`` for X.
        conjY: Pass ``cy.CONJUGATE`` instead of ``cy.NO_CONJUGATE`` for Y.

    Returns:
        The value returned by ``cy.dotv`` (both operands use unit stride).

    Raises:
        ValueError: if ``X`` and ``Y`` differ in length.
    """
    cdef Py_ssize_t n_x = X.shape[0]
    cdef Py_ssize_t n_y = Y.shape[0]
    if n_x == n_y:
        return cy.dotv(
            cy.CONJUGATE if conjX else cy.NO_CONJUGATE,
            cy.CONJUGATE if conjY else cy.NO_CONJUGATE,
            n_x, &X[0], &Y[0], 1, 1)
    msg = "Shape mismatch for blis.dotv: (%d,), (%d,)"
    raise ValueError(msg % (n_x, n_y))
def einsum(todo, A, B, out=None):
    """Evaluate a two-operand einsum spec by dispatching to a BLIS wrapper.

    Only the specs handled below are supported. The operands are not checked
    against the spec here; the wrapper that gets called validates what it
    can.

    Args:
        todo: einsum subscript string such as ``'ab,bc->ac'``.
        A: First operand.
        B: Second operand.
        out: Optional output array. The gemm/gemv wrappers default to
            ``beta=1``, so a supplied ``out`` has the result added into it.
            The ``'a,a->a'`` branch accumulates into ``out`` the same way.

    Returns:
        The result array (``out`` itself when one is supplied).

    Raises:
        ValueError: if ``todo`` is not a supported spec.
    """
    if todo == 'a,a->a':
        # Elementwise product. axpy() cannot express this: its second
        # parameter is the scalar ``scale``, so an array B cannot go there.
        product = numpy.multiply(A, B)
        if out is None:
            return product
        out += product
        return out
    elif todo == 'a,b->ab':
        # Outer product. ger() declares its first operand 2-D and reads it as
        # a contiguous vector, so hand it a contiguous (m, 1) column.
        return ger(numpy.ascontiguousarray(A).reshape((-1, 1)), B, out=out)
    elif todo == 'a,b->ba':
        return ger(numpy.ascontiguousarray(B).reshape((-1, 1)), A, out=out)
    elif todo == 'ab,a->ab':
        # NOTE(review): batch_axpy is a stub, so these two specs do not
        # produce a result yet.
        return batch_axpy(A, B, out=out)
    elif todo == 'ab,a->ba':
        return batch_axpy(A, B, trans1=True, out=out)
    elif todo == 'ab,b->a':
        return gemv(A, B, out=out)
    elif todo == 'ab,a->b':
        return gemv(A, B, trans1=True, out=out)
    # gemm specs. Rule: the operand containing the first output index is
    # arg1; set trans1 if that index is arg1's second dimension. The other
    # operand is arg2; set trans2 if the contracted (summed) index is arg2's
    # second dimension.
    # E.g. for 'ab,ac->bc': b occurs in ab as its second dim, so arg1=A with
    # trans1=True; arg2=B stores the contracted a first, so trans2=False.
    elif todo == 'ab,ac->bc':
        return gemm(A, B, trans1=True, trans2=False, out=out)
    elif todo == 'ab,ac->cb':
        # c is ac's second dim -> arg1=B, trans1=True; the contracted a is
        # ab's first dim -> trans2=False (computes B^T @ A).
        return gemm(B, A, out=out, trans1=True, trans2=False)
    elif todo == 'ab,bc->ac':
        return gemm(A, B, out=out, trans1=False, trans2=False)
    elif todo == 'ab,bc->ca':
        return gemm(B, A, out=out, trans1=True, trans2=True)
    elif todo == 'ab,ca->bc':
        return gemm(A, B, out=out, trans1=True, trans2=True)
    elif todo == 'ab,ca->cb':
        return gemm(B, A, out=out, trans1=False, trans2=False)
    elif todo == 'ab,cb->ac':
        return gemm(A, B, out=out, trans1=False, trans2=True)
    elif todo == 'ab,cb->ca':
        return gemm(B, A, out=out, trans1=False, trans2=True)
    else:
        raise ValueError("Invalid einsum: %s" % todo)