187 lines
6.6 KiB
Cython
187 lines
6.6 KiB
Cython
|
# cython: boundscheck=False
|
||
|
# Copyright ExplosionAI GmbH, released under BSD.
|
||
|
cimport numpy as np
|
||
|
from . cimport cy
|
||
|
from .cy cimport reals1d_ft, reals2d_ft, float1d_t, float2d_t
|
||
|
from .cy cimport const_reals1d_ft, const_reals2d_ft, const_float1d_t, const_float2d_t
|
||
|
from .cy cimport const_double1d_t, const_double2d_t
|
||
|
|
||
|
import numpy
|
||
|
|
||
|
|
||
|
def axpy(const_reals1d_ft A, double scale=1., np.ndarray out=None):
    """Add a scaled vector into an accumulator using cy.axpyv.

    A: 1-d single- or double-precision vector; the fused type selects the
        branch below.
    scale: factor handed to cy.axpyv as the multiplier for A.
    out: optional accumulator. When None, a zero-filled array of length
        A.shape[0] with dtype 'f' or 'd' (matching A) is allocated.
        NOTE(review): a caller-supplied `out` is only accessed through its
        raw data pointer; its dtype, length and contiguity are not checked
        here, so callers presumably must guarantee they match A.

    Returns `out`.
    Raises TypeError if the fused type is neither float nor double.
    """
    if const_reals1d_ft is const_float1d_t:
        if out is None:
            out = numpy.zeros((A.shape[0],), dtype='f')
        B = <float*>out.data
        # The single-precision branch previously returned `out` without ever
        # calling axpyv, so float32 input silently produced an untouched
        # buffer. Mirror the double-precision branch.
        with nogil:
            cy.axpyv(cy.NO_CONJUGATE, A.shape[0], scale, &A[0], 1, B, 1)
        return out
    elif const_reals1d_ft is const_double1d_t:
        if out is None:
            out = numpy.zeros((A.shape[0],), dtype='d')
        B = <double*>out.data
        with nogil:
            cy.axpyv(cy.NO_CONJUGATE, A.shape[0], scale, &A[0], 1, B, 1)
        return out
    else:
        B = NULL
        raise TypeError("Unhandled fused type")
|
||
|
|
||
|
|
||
|
def batch_axpy(reals2d_ft A, reals1d_ft B, np.ndarray out=None):
    """Placeholder for a batched axpy; currently not implemented.

    The body performs no computation: A, B and out are ignored and the
    function returns None.

    NOTE(review): einsum() in this module routes 'ab,a->ab' and 'ab,a->ba'
    here, and the latter passes a `trans1` keyword that this signature does
    not accept.
    """
    return None
|
||
|
|
||
|
|
||
|
def ger(const_reals2d_ft A, const_reals1d_ft B, double scale=1., np.ndarray out=None):
    """Rank-1 update through cy.ger, accumulated into `out`.

    A: 2-d float or double array. Its address &A[0,0] is handed to cy.ger as
        the first vector operand with increment 1 and length A.shape[0].
        NOTE(review): only A.shape[0] consecutive elements are read, which
        presumably expects A to be laid out as a column vector -- confirm.
    B: 1-d array of the same precision as A, used as the second vector
        operand with increment 1.
    scale: multiplier passed to cy.ger.
    out: optional (A.shape[0], B.shape[0]) accumulator; a zero-filled array
        of the matching dtype is allocated when None. Its data pointer is
        used with row stride out.shape[1] and column stride 1.

    Returns `out`.
    Raises TypeError when A and B are not both float or both double.
    """
    if const_reals1d_ft is const_double1d_t and const_reals2d_ft is const_double2d_t:
        # Double-precision operands.
        if out is None:
            out = numpy.zeros((A.shape[0], B.shape[0]), dtype='d')
        with nogil:
            cy.ger(
                cy.NO_CONJUGATE, cy.NO_CONJUGATE,
                A.shape[0], B.shape[0],
                scale,
                &A[0,0], 1,
                &B[0], 1,
                <double*>out.data, out.shape[1], 1
            )
        return out
    elif const_reals1d_ft is const_float1d_t and const_reals2d_ft is const_float2d_t:
        # Single-precision operands.
        if out is None:
            out = numpy.zeros((A.shape[0], B.shape[0]), dtype='f')
        with nogil:
            cy.ger(
                cy.NO_CONJUGATE, cy.NO_CONJUGATE,
                A.shape[0], B.shape[0],
                scale,
                &A[0,0], 1,
                &B[0], 1,
                <float*>out.data, out.shape[1], 1
            )
        return out
    else:
        # Mixed-precision specialisations end up here.
        C = NULL
        raise TypeError("Unhandled fused type")
|
||
|
|
||
|
|
||
|
def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    """Matrix-matrix multiply through cy.gemm, accumulated into `out`.

    A, B: 2-d arrays of the same precision (both float or both double).
    out: optional accumulator of shape (nM, nN); a zero-filled array of the
        matching dtype is allocated when None. Its data pointer is used with
        row stride out.shape[1] and column stride 1.
    trans1, trans2: when true, A (resp. B) is passed to cy.gemm flagged as
        transposed.
    alpha, beta: scalars forwarded to cy.gemm (multiplier of the product and
        of the existing contents of `out`, per the BLIS gemm convention).

    The logical dimensions are derived from the operand shapes after
    transposition: nM rows and nK columns for op(A), nN columns for op(B).
    NOTE(review): the inner dimensions of op(A) and op(B), and the shape and
    dtype of a caller-supplied `out`, are not validated here.

    Returns `out`.
    Raises TypeError if the fused type is neither float nor double.
    """
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    elif const_reals2d_ft is const_double2d_t:
        # The double branch used the raw shapes (A.shape[0], B.shape[1],
        # A.shape[1]) for both the allocation and the m/n/k arguments, which
        # is only right when neither operand is transposed. Use the
        # transpose-aware dimensions, exactly like the float branch.
        if out is None:
            out = numpy.zeros((nM, nN), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")
|
||
|
|
||
|
def gemv(const_reals2d_ft A, const_reals1d_ft B,
         bint trans1=False, double alpha=1., double beta=1.,
         np.ndarray out=None):
    """Matrix-vector multiply through cy.gemv, accumulated into `out`.

    A: 2-d float or double matrix, passed with row stride A.shape[1] and
        column stride 1.
    B: 1-d vector of the same precision as A, passed with increment 1.
    trans1: when true, A is passed to cy.gemv flagged as transposed.
    alpha, beta: scalars forwarded to cy.gemv (multiplier of the product and
        of the existing contents of `out`, per the BLIS gemv convention).
    out: optional accumulator; when None a zero-filled vector of the
        matching dtype is allocated. Its length is A.shape[0] for the plain
        product and A.shape[1] when A is transposed.
        NOTE(review): the length of B and the size/dtype of a caller-supplied
        `out` are not validated here.

    Returns `out`.
    Raises TypeError when A and B are not both float or both double.
    """
    # A transposed product yields one entry per column of A. The previous
    # code always allocated A.shape[0] entries, which gave a wrongly sized
    # result (and a possible out-of-bounds write) whenever trans1 was set on
    # a non-square matrix.
    cdef cy.dim_t out_len = A.shape[1] if trans1 else A.shape[0]
    if const_reals1d_ft is const_float1d_t and const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((out_len,), dtype='f')
        with nogil:
            cy.gemv(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.NO_CONJUGATE,
                A.shape[0], A.shape[1],
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0], 1,
                beta,
                <float*>out.data, 1)
        return out
    elif const_reals1d_ft is const_double1d_t and const_reals2d_ft is const_double2d_t:
        if out is None:
            out = numpy.zeros((out_len,), dtype='d')
        with nogil:
            cy.gemv(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.NO_CONJUGATE,
                A.shape[0], A.shape[1],
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0], 1,
                beta,
                <double*>out.data, 1)
        return out
    else:
        raise TypeError("Unhandled fused type")
|
||
|
|
||
|
|
||
|
def dotv(const_reals1d_ft X, const_reals1d_ft Y, bint conjX=False, bint conjY=False):
    """Dot product of two equally long 1-d vectors via cy.dotv.

    X, Y: 1-d vectors of the same fused type, both read with increment 1.
    conjX, conjY: select cy.CONJUGATE instead of cy.NO_CONJUGATE for the
        corresponding operand.

    Returns whatever cy.dotv returns.
    Raises ValueError when the two vectors differ in length.
    """
    cdef Py_ssize_t len_x = X.shape[0]
    cdef Py_ssize_t len_y = Y.shape[0]
    if len_x == len_y:
        return cy.dotv(
            cy.CONJUGATE if conjX else cy.NO_CONJUGATE,
            cy.CONJUGATE if conjY else cy.NO_CONJUGATE,
            len_x, &X[0], &Y[0], 1, 1
        )
    raise ValueError("Shape mismatch for blis.dotv: (%d,), (%d,)" % (len_x, len_y))
|
||
|
|
||
|
|
||
|
def einsum(todo, A, B, out=None):
    """Dispatch a two-operand einsum specification to a BLIS wrapper.

    todo: subscript string such as 'ab,bc->ac'. Only the exact strings
        listed below are recognised.
    A, B: the operands named by the first and second subscript group.
    out: optional output array forwarded to the selected routine.

    Returns the result of the selected routine.
    Raises ValueError for any specification that is not listed.
    """
    if todo == 'a,a->a':
        # NOTE(review): B is forwarded positionally into axpy's `scale`
        # parameter (a double) -- confirm this is the intended mapping.
        return axpy(A, B, out=out)
    # NOTE(review): ger declares its first operand as a 2-d array, while the
    # two specs below name it with a single subscript -- confirm the
    # expected layout with callers.
    elif todo == 'a,b->ab':
        return ger(A, B, out=out)
    elif todo == 'a,b->ba':
        return ger(B, A, out=out)
    # NOTE(review): batch_axpy is an unimplemented stub in this module and
    # does not accept `trans1`, so the second call below fails with a
    # TypeError.
    elif todo == 'ab,a->ab':
        return batch_axpy(A, B, out=out)
    elif todo == 'ab,a->ba':
        return batch_axpy(A, B, trans1=True, out=out)
    elif todo == 'ab,b->a':
        return gemv(A, B, out=out)
    elif todo == 'ab,a->b':
        return gemv(A, B, trans1=True, out=out)
    # Rule for the gemm cases: the first output dimension must come from the
    # first gemm argument, so pick that operand as arg1 and set trans1 when
    # the dimension is its second axis. The second output dimension then
    # fixes trans2 on the other operand in the same way (set it when the
    # dimension is that operand's first axis).
    # E.g. for 'ab,ac->bc', b occurs in 'ab', so A is arg1 and needs
    # trans1=True, giving ba,ac->bc.
    elif todo == 'ab,ac->bc':
        return gemm(A, B, trans1=True, trans2=False, out=out)
    elif todo == 'ab,ac->cb':
        # c is the second axis of B (trans1=True) and b is already the
        # second axis of A, so A must NOT be transposed: ca,ab->cb. This
        # previously passed trans2=True, multiplying (c,a) by (b,a).
        return gemm(B, A, out=out, trans1=True, trans2=False)
    elif todo == 'ab,bc->ac':
        return gemm(A, B, out=out, trans1=False, trans2=False)
    elif todo == 'ab,bc->ca':
        return gemm(B, A, out=out, trans1=True, trans2=True)
    elif todo == 'ab,ca->bc':
        return gemm(A, B, out=out, trans1=True, trans2=True)
    elif todo == 'ab,ca->cb':
        return gemm(B, A, out=out, trans1=False, trans2=False)
    elif todo == 'ab,cb->ac':
        return gemm(A, B, out=out, trans1=False, trans2=True)
    elif todo == 'ab,cb->ca':
        return gemm(B, A, out=out, trans1=False, trans2=True)
    else:
        raise ValueError("Invalid einsum: %s" % todo)
|