Spaces:
Sleeping
Sleeping
File size: 4,169 Bytes
8aa4f1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import re
import os
import cupy
import os.path as osp
import torch
@cupy.memoize(for_each_device=True)
def launch_kernel(strFunction, strKernel):
    """Return a ``cupy.RawKernel`` for the function ``strFunction`` in ``strKernel``.

    The call is wrapped in ``cupy.memoize(for_each_device=True)``, so repeated
    calls with the same arguments on the same device reuse the cached result.

    Args:
        strFunction: Name of the kernel function inside the source.
        strKernel: CUDA source code, typically the output of
            ``preprocess_kernel``.

    Returns:
        The ``cupy.RawKernel`` built from the given source and function name.
    """
    if 'CUDA_HOME' not in os.environ:
        strCudaPath = cupy.cuda.get_cuda_path()
        # get_cuda_path() yields None when no CUDA installation is found.
        # os.environ only accepts strings, so assigning None would raise a
        # TypeError here before the kernel is even built; leave the variable
        # unset in that case.
        if strCudaPath is not None:
            os.environ['CUDA_HOME'] = strCudaPath

    # Include-path options that were once passed to RawKernel, left disabled:
    # , options=tuple([ '-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include' ])

    # Note the argument order: RawKernel takes the source first, then the name.
    return cupy.RawKernel(strKernel, strFunction)
# Pseudo-macros expanded by preprocess_kernel. Raw strings are used so that
# the regex escapes are not interpreted as (invalid) Python string escapes.
_SIZE_MACRO = re.compile(r'(SIZE_)([0-4])(\()([^\)]*)(\))')
_STRIDE_MACRO = re.compile(r'(STRIDE_)([0-4])(\()([^\)]*)(\))')
_OFFSET_MACRO = re.compile(r'(OFFSET_)([0-4])(\()([^\)]+)(\))')
_VALUE_MACRO = re.compile(r'(VALUE_)([0-4])(\()([^\)]+)(\))')


def _scalar_text(objValue):
    """Render a size/stride entry as text, unwrapping tensor-valued entries."""
    return str(objValue.item() if torch.is_tensor(objValue) else objValue)


def _expand_macros(strKernel, objPattern, funcRender):
    """Replace every occurrence of a macro until the pattern no longer matches.

    Each round finds the first match and replaces all textually identical
    occurrences with ``funcRender(match)``.
    """
    while True:
        objMatch = objPattern.search(strKernel)
        if objMatch is None:
            return strKernel
        strKernel = strKernel.replace(objMatch.group(0), funcRender(objMatch))


def _flat_index(objMatch, objVariables):
    """Return ``(tensor_name, index_expression)`` for an OFFSET_/VALUE_ match.

    The index expression is the sum of ``(index) * stride`` terms, one per
    dimension requested by the digit in the macro name. Curly braces inside an
    index argument are turned into parentheses.
    """
    intArgs = int(objMatch.group(2))
    strArgs = objMatch.group(4).split(',')
    strTensor = strArgs[0]
    intStrides = objVariables[strTensor].stride()
    strIndex = [
        '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*'
        + _scalar_text(intStrides[intArg]) + ')'
        for intArg in range(intArgs)
    ]
    return strTensor, '+'.join(strIndex)


def preprocess_kernel(strKernel, objVariables):
    """Expand template placeholders in CUDA source and prepend a common preamble.

    The preamble includes ``helper_math.h`` (looked up next to this file) and
    defines float ``atomicMin`` / ``atomicMax`` device functions built on
    ``atomicCAS``. The following substitutions are then applied, in order:

    * ``{{name}}`` is replaced by ``objVariables[name]`` when that value is
      exactly an ``int``, ``float`` or ``str``.
    * ``SIZE_d(name)`` is replaced by entry ``d`` of ``objVariables[name].size()``.
    * ``STRIDE_d(name)`` is replaced by entry ``d`` of
      ``objVariables[name].stride()``.
    * ``OFFSET_n(name, i0, ..., i{n-1})`` is replaced by the parenthesised flat
      offset ``i0 * stride0 + ...``.
    * ``VALUE_n(name, i0, ..., i{n-1})`` is replaced by ``name[<flat offset>]``.

    Args:
        strKernel: CUDA source containing the placeholders above.
        objVariables: Mapping from placeholder name to its value. Names used in
            SIZE_/STRIDE_/OFFSET_/VALUE_ must map to objects providing
            ``size()`` / ``stride()``.

    Returns:
        The expanded source string.

    Raises:
        KeyError: If a macro refers to a name missing from ``objVariables``.
        IndexError: If a macro requests a dimension, or an OFFSET_/VALUE_ macro
            has fewer index arguments, than is available.
    """
    strHelperPath = osp.join(osp.dirname(osp.abspath(__file__)), 'helper_math.h')

    strKernel = '''
#include <{{HELPER_PATH}}>
__device__ __forceinline__ float atomicMin(const float* buffer, float dblValue) {
int intValue = __float_as_int(*buffer);
while (__int_as_float(intValue) > dblValue) {
intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
}
return __int_as_float(intValue);
}
__device__ __forceinline__ float atomicMax(const float* buffer, float dblValue) {
int intValue = __float_as_int(*buffer);
while (__int_as_float(intValue) < dblValue) {
intValue = atomicCAS((int*) (buffer), intValue, __float_as_int(dblValue));
}
return __int_as_float(intValue);
}
'''.replace('{{HELPER_PATH}}', strHelperPath) + strKernel

    for strVariable in objVariables:
        objValue = objVariables[strVariable]
        # An exact type match (rather than isinstance) is kept so that bool
        # values and other subclasses stay unsubstituted, as they always have.
        if type(objValue) in (int, float):
            strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
        elif type(objValue) is str:
            strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)

    def render_size(objMatch):
        intSizes = objVariables[objMatch.group(4)].size()
        return _scalar_text(intSizes[int(objMatch.group(2))])

    def render_stride(objMatch):
        intStrides = objVariables[objMatch.group(4)].stride()
        return _scalar_text(intStrides[int(objMatch.group(2))])

    def render_offset(objMatch):
        _, strIndex = _flat_index(objMatch, objVariables)
        return '(' + strIndex + ')'

    def render_value(objMatch):
        strTensor, strIndex = _flat_index(objMatch, objVariables)
        return strTensor + '[' + strIndex + ']'

    strKernel = _expand_macros(strKernel, _SIZE_MACRO, render_size)
    strKernel = _expand_macros(strKernel, _STRIDE_MACRO, render_stride)
    strKernel = _expand_macros(strKernel, _OFFSET_MACRO, render_offset)
    strKernel = _expand_macros(strKernel, _VALUE_MACRO, render_value)

    return strKernel