hqjenny-centrifuge/examples/gemx_tl/python/gemm.py

135 lines
4.5 KiB
Python

from ctypes import *
import timeit
import numpy as np
import sys
class GEMMManager:
    """ctypes front end for a GEMM host shared library.

    The instance owns a single opaque host handle (``self._handle``) that is
    created by ``createHandle`` and released by ``closeHandle``; every other
    method forwards that handle to the matching C entry point.
    """

    def __init__(self, libFile):
        """Load the shared library at ``libFile`` and declare its C signatures."""
        self._handle = None
        self._lib = cdll.LoadLibrary(libFile)

        # ndpointer argtypes make ctypes check dtype and memory layout of the
        # NumPy arrays before they are handed to the C side.
        short_array = np.ctypeslib.ndpointer(c_short, flags="C_CONTIGUOUS")
        int_array = np.ctypeslib.ndpointer(c_int, flags="C_CONTIGUOUS")

        lib = self._lib
        lib.DestroyGEMMHost.argtypes = [c_void_p]
        lib.MakeGEMMHost.argtypes = [c_char_p, c_char_p]
        # Without an explicit restype ctypes would treat the result as a C int.
        lib.MakeGEMMHost.restype = c_void_p
        lib.SendToFPGAShrt_GEMM.argtypes = [c_void_p, short_array, c_ulonglong, c_bool]
        lib.SendToFPGAInt_GEMM.argtypes = [c_void_p, int_array, c_ulonglong, c_bool]
        lib.AddGEMMOp.argtypes = [
            c_void_p,
            short_array,
            short_array,
            short_array,
            int_array,
            c_uint, c_uint, c_uint, c_int, c_int,
        ]
        lib.AddGEMMOp.restype = c_bool
        lib.Execute_GEMM.argtypes = [c_void_p]
        lib.GetFromFPGA_GEMM.argtypes = [c_void_p, short_array, c_bool]
        lib.GetFromFPGA_GEMM.restype = c_void_p
        lib.Wait_GEMM.argtypes = [c_void_p]

    def createHandle(self, xclbin, kernel, numHandles):
        """Create the host handle from the ``xclbin`` and ``kernel`` strings.

        Both strings are UTF-8 encoded before being passed to ``MakeGEMMHost``.
        ``numHandles`` is accepted for interface compatibility but is not used.
        """
        b_xclbin = xclbin.encode('utf-8')
        b_kernel = kernel.encode('utf-8')
        self._handle = self._lib.MakeGEMMHost(b_xclbin, b_kernel)

    def closeHandle(self):
        """Destroy the host handle if one exists; calling it again is a no-op."""
        if not self._handle:
            return
        self._lib.DestroyGEMMHost(self._handle)
        # Forget the handle so a repeated close cannot pass a destroyed
        # handle back into the library.
        self._handle = None

    def addGEMMOp(self, A, B, C, bias, postScale, postShift):
        """Queue one GEMM operation on the handle.

        The three dimensions passed to the library are ``A.shape[0]``,
        ``A.shape[1]`` and ``B.shape[1]``. The process exits via
        ``sys.exit()`` when the library reports failure.
        """
        ret = self._lib.AddGEMMOp(
            self._handle, A, B, C, bias,
            c_uint(A.shape[0]), c_uint(A.shape[1]), c_uint(B.shape[1]),
            c_int(postScale), c_int(postShift),
        )
        if not ret:
            sys.exit()

    def execute(self):
        """Call ``Execute_GEMM`` on the handle."""
        self._lib.Execute_GEMM(self._handle)

    def wait(self):
        """Call ``Wait_GEMM`` on the handle."""
        self._lib.Wait_GEMM(self._handle)

    def sendMat(self, A, sync_send=False):
        """Send the 2-D array ``A`` through the entry point matching its dtype.

        ``int32`` arrays use ``SendToFPGAInt_GEMM`` and ``int16`` arrays use
        ``SendToFPGAShrt_GEMM``; the size argument is
        ``A.shape[0] * A.shape[1]``. Any other dtype prints a message and
        exits the process.
        """
        if A.dtype == np.int32:
            self._lib.SendToFPGAInt_GEMM(self._handle, A, c_ulonglong(A.shape[0] * A.shape[1]), sync_send)
        elif A.dtype == np.int16:
            self._lib.SendToFPGAShrt_GEMM(self._handle, A, c_ulonglong(A.shape[0] * A.shape[1]), sync_send)
        else:
            print ("sendMat: ", A.dtype, " type not supported")
            sys.exit()

    def getMat(self, A, sync_get=True):
        """Call ``GetFromFPGA_GEMM`` for the int16 array ``A``; returns nothing."""
        self._lib.GetFromFPGA_GEMM(self._handle, A, sync_get)

    def matmul_addbias(self, A, B, bias, sendA=True, sendB=True, sendBias=True):
        """Run one GEMM-with-bias operation and return the result buffer.

        ``A`` and ``B`` must be 2-D with ``A.shape[1] == B.shape[0]``;
        otherwise a message is printed and the process exits. The returned
        array is a freshly allocated int16 buffer of shape
        ``(A.shape[0], B.shape[1])`` that was handed to the library as the
        output operand (presumably holding ``A @ B + bias`` afterwards;
        the arithmetic itself happens inside the library).

        All transfers go through this instance's own methods, so the call
        works on any manager, not only the module-level one.
        """
        if len(A.shape) != 2:
            print ("A matrix not 2-D")
            sys.exit()
        if len(B.shape) != 2:
            print ("B matrix not 2-D")
            sys.exit()
        if A.shape[1] != B.shape[0]:
            print ("Can't perform GEMM as A dim[1] %d and B dim[0] %d are different" % (A.shape[1], B.shape[0]) )
            sys.exit()

        if sendA:
            self.sendMat(A)
        if sendB:
            self.sendMat(B)
        if sendBias:
            self.sendMat(bias)

        C_fpga = np.zeros((A.shape[0], B.shape[1]), dtype=np.int16)
        self.sendMat(C_fpga)
        # NOTE(review): bias is sent here regardless of sendBias, so with the
        # default flags it is transferred twice. Kept as-is to preserve the
        # existing transfer sequence; confirm whether this send is needed.
        self.sendMat(bias)
        self.addGEMMOp(A, B, C_fpga, bias, 1, 0)
        self.execute()
        self.getMat(C_fpga)
        return C_fpga
# Module-level GEMMManager used by the convenience functions in this module.
# It stays None until createHandle() instantiates it.
_gemxManager = None
def sendMat(A, sync_send=False):
    """Send ``A`` through the module-level manager; returns nothing."""
    manager = _gemxManager
    manager.sendMat(A, sync_send)
def getMat(A, sync_get=True):
    """Forward to the module-level manager's ``getMat`` and return its result."""
    result = _gemxManager.getMat(A, sync_get)
    return result
def addGEMMOp(A, B, C, bias, postScale, postShift):
    """Queue a GEMM operation on the module-level manager; returns nothing."""
    operands = (A, B, C, bias)
    _gemxManager.addGEMMOp(*operands, postScale, postShift)
def execute():
    """Call ``execute`` on the module-level manager; returns nothing."""
    manager = _gemxManager
    manager.execute()
def wait():
    """Call ``wait`` on the module-level manager; returns nothing."""
    manager = _gemxManager
    manager.wait()
def matmul(A, B, SendA=True, SendB=True):
    """Run ``matmul_addbias`` on the module-level manager with an all-zero bias.

    The bias is an int32 array of shape ``(A.shape[0], B.shape[1])`` and is
    always flagged to be sent.
    """
    out_rows = A.shape[0]
    out_cols = B.shape[1]
    zero_bias = np.zeros((out_rows, out_cols), dtype=np.int32)
    return _gemxManager.matmul_addbias(A, B, zero_bias, SendA, SendB, True)
def matmul_addbias(A, B, bias, SendA=True, SendB=True, SendBias=True):
    """Forward to the module-level manager's ``matmul_addbias`` and return its result."""
    send_flags = (SendA, SendB, SendBias)
    result = _gemxManager.matmul_addbias(A, B, bias, *send_flags)
    return result
def createManager(libFile):
    """Build and return a new ``GEMMManager`` for the library at ``libFile``.

    The module-level manager is not touched.
    """
    manager = GEMMManager(libFile)
    return manager
def createHandle(xclbin, kernel, libFile, numHandles=1):
    """Create a handle on the module-level manager, building the manager on first use.

    ``libFile`` is only consulted when no module-level manager exists yet.
    Returns whatever the manager's ``createHandle`` returns.
    """
    global _gemxManager
    manager = _gemxManager
    if not manager:
        manager = GEMMManager(libFile)
        _gemxManager = manager
    return manager.createHandle(xclbin, kernel, numHandles)
def closeHandle():
    """Close the module-level manager's handle.

    Does nothing when no module-level manager has been created yet, instead
    of failing with an AttributeError on ``None``. Otherwise returns whatever
    the manager's ``closeHandle`` returns.
    """
    if _gemxManager is None:
        return None
    return _gemxManager.closeHandle()