diff --git a/setup.py b/setup.py
index 7b189ab..97a4f38 100644
--- a/setup.py
+++ b/setup.py
@@ -11,11 +11,23 @@
 if torch.cuda.is_available():
     assert torch.matmul(torch.ones(2097153,2).cuda(),torch.ones(2,2).cuda()).min().item()==2, 'Please upgrade from CUDA 9.0 to CUDA 10.0+'
 
+use_halide = False
+
 this_dir = os.path.dirname(os.path.realpath(__file__))
 torch_dir = os.path.dirname(torch.__file__)
 conda_include_dir = '/'.join(torch_dir.split('/')[:-4]) + '/include'
 
+extra = {'cxx': ['-std=c++14', '-fopenmp', '-mavx', '-march=native', '-mtune=native'], 'nvcc': ['-std=c++14', '-Xcompiler', '-fopenmp']}
+
+extra_link = []
+include_dirs = [conda_include_dir, this_dir+'/sparseconvnet/SCN/']
+build_dirs = ['sparseconvnet/SCN/pybind.cpp', 'sparseconvnet/SCN/sparseconvnet_cpu.cpp']
-extra = {'cxx': ['-std=c++11', '-fopenmp'], 'nvcc': ['-std=c++11', '-Xcompiler', '-fopenmp']}
+if use_halide:
+    halide_include_dir = os.environ['CONDA_PREFIX']+'/include'
+    halide_mul = 'sparseconvnet/SCN/Halide/mul.cpp'
+    extra_link = ['-lHalide', '-ldl', '-lz']
+    include_dirs.append(halide_include_dir)
+    build_dirs.append(halide_mul)
 
 setup(
     name='sparseconvnet',
@@ -27,14 +39,14 @@
     packages=['sparseconvnet','sparseconvnet.SCN'],
     ext_modules=[
         CUDAExtension('sparseconvnet.SCN',
-                      [
-                          'sparseconvnet/SCN/cuda.cu', 'sparseconvnet/SCN/sparseconvnet_cuda.cpp', 'sparseconvnet/SCN/pybind.cpp'],
+                      ['sparseconvnet/SCN/cuda.cu', 'sparseconvnet/SCN/sparseconvnet_cuda.cpp', 'sparseconvnet/SCN/pybind.cpp'],
                       include_dirs=[conda_include_dir, this_dir+'/sparseconvnet/SCN/'],
                       extra_compile_args=extra)
-        if torch.cuda.is_available() else
+        if torch.cuda.is_available() else
         CppExtension('sparseconvnet.SCN',
-                     ['sparseconvnet/SCN/pybind.cpp', 'sparseconvnet/SCN/sparseconvnet_cpu.cpp'],
-                     include_dirs=[conda_include_dir, this_dir+'/sparseconvnet/SCN/'],
+                     build_dirs,
+                     include_dirs=include_dirs,
+                     extra_link_args=extra_link,
                      extra_compile_args=extra['cxx'])],
     cmdclass={'build_ext': BuildExtension},
     zip_safe=False,
diff --git a/sparseconvnet/SCN/CPU/Convolution.cpp b/sparseconvnet/SCN/CPU/Convolution.cpp
index b97a197..32c5e66 100644
--- a/sparseconvnet/SCN/CPU/Convolution.cpp
+++ b/sparseconvnet/SCN/CPU/Convolution.cpp
@@ -25,6 +25,36 @@ at::Tensor rule_index_select(at::Tensor &src, Int nRules, const Int *rules,
 
 // groups x rows x planes -> rows x groups x planes
+#if defined(__AVX__)
+#include <immintrin.h>
+template <typename T>
+void rule_index_add_(at::Tensor &target, at::Tensor &src, int nRules,
+                     const int *rules, int groups) {
+  auto planes = target.size(1) / groups;
+  auto s_ptr = src.data<float>();
+  auto t_ptr = target.data<float>();
+
+#pragma omp parallel for
+  for (Int i = 0; i < nRules; ++i) {
+    for (Int g = 0; g < groups; ++g) {
+      auto s = s_ptr + (g * nRules + i) * planes;
+      auto t = t_ptr + (rules[2 * i] * groups + g) * planes;
+      int rem = planes % 8;
+      int j = 0;
+
+      for (; j < planes - rem; j += 8) {
+        auto tar = _mm256_loadu_ps(t + j);
+        auto sour = _mm256_loadu_ps(s + j);
+        auto res = _mm256_add_ps(tar, sour);
+        _mm256_storeu_ps(t + j, res);
+      }
+
+      for (Int r = 0; r < rem; ++r)
+        t[j + r] += s[j + r];
+    }
+  }
+}
+#else
 template <typename T>
 void rule_index_add_(at::Tensor &target, at::Tensor &src, Int nRules,
                      const Int *rules, Int groups) {
@@ -41,6 +71,7 @@ void rule_index_add_(at::Tensor &target, at::Tensor &src, Int nRules,
     }
   }
 }
+#endif
 
 template <typename T>
 double cpu_Convolution_updateOutput(
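For reference, the following minimal standalone sketch (not part of the patch) shows the scatter-add pattern the new AVX branch of rule_index_add_ implements, simplified to groups = 1: each source row i is accumulated into the target row rules[2 * i], eight floats per _mm256_add_ps step, with a scalar loop covering the planes % 8 tail. The helper and driver names are illustrative; it assumes compilation with -mavx, which the setup.py change above now passes. Unaligned loads and stores are used because feature-matrix rows carry no 32-byte alignment guarantee.

#include <immintrin.h>
#include <cstdio>

// Accumulate source row i into target row rules[2 * i]:
// 8 floats per AVX step, scalar loop for the planes % 8 tail.
static void rule_index_add_sketch(float *target, const float *src, int nRules,
                                  const int *rules, int planes) {
  for (int i = 0; i < nRules; ++i) {
    const float *s = src + i * planes;
    float *t = target + rules[2 * i] * planes;
    int rem = planes % 8;
    int j = 0;
    for (; j < planes - rem; j += 8)
      _mm256_storeu_ps(t + j, _mm256_add_ps(_mm256_loadu_ps(t + j),
                                            _mm256_loadu_ps(s + j)));
    for (int r = 0; r < rem; ++r)
      t[j + r] += s[j + r];
  }
}

int main() {
  float target[2 * 10] = {0.0f};
  float src[2 * 10];
  for (int k = 0; k < 20; ++k)
    src[k] = 1.0f;
  int rules[4] = {1, 0, 0, 0}; // source row 0 -> target row 1, row 1 -> row 0
  rule_index_add_sketch(target, src, 2, rules, 10);
  std::printf("%f %f\n", target[0], target[10]); // expect 1.000000 1.000000
  return 0;
}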
diff --git a/sparseconvnet/SCN/Halide/mul.cpp b/sparseconvnet/SCN/Halide/mul.cpp
new file mode 100644
index 0000000..4e1521a
--- /dev/null
+++ b/sparseconvnet/SCN/Halide/mul.cpp
@@ -0,0 +1,271 @@
+#define _GLIBCXX_USE_CXX11_ABI 1
+#define HL_PERMIT_FAILED_UNROLL 1
+
+#include "mul.hpp"
+
+#include "Halide.h"
+#include "HalideBuffer.h"
+
+#include <unordered_map>
+
+/* Estimates for some of the Halide parameters */
+static const int maxHalideRow = 1000000;
+static const int featureCount = 32;
+static const int activeRows = 60000;
+static const int groups = 1;
+static const int featureRowCount = 100000;
+
+template <typename Operation>
+using MulStrategyMap =
+    std::unordered_map<LayerDimensions, std::unique_ptr<Operation>,
+                       LayerDimensionsHash>;
+
+template <typename Operation>
+const Operation &getHalideMul(int inFeatureCount, int outFeatureCount,
+                              int groups, bool cuda,
+                              MulStrategyMap<Operation> &container) {
+  const LayerDimensions dims = {inFeatureCount, outFeatureCount, groups, cuda};
+  auto it = container.find(dims);
+
+  if (it != container.end()) {
+    return *(it->second);
+  }
+
+  auto mul =
+      container.insert(std::make_pair(dims, std::make_unique<Operation>(dims)))
+          .first->second.get();
+  return *mul;
+}
+
+struct HalideMulFactory::Impl {
+  MulStrategyMap<HalideMulBackward> backward;
+  MulStrategyMap<HalideMulForward> forward;
+};
+
+HalideMulFactory::HalideMulFactory() : pimpl(new Impl()) {}
+
+HalideMulFactory::~HalideMulFactory() = default;
+
+const HalideMulFactory &HalideMulFactory::getInstance() {
+  static HalideMulFactory instance;
+  return instance;
+}
+
+const HalideMulForward &
+HalideMulFactory::getHalideMulForward(int inFeatureCount, int outFeatureCount,
+                                      int groups, bool cuda) const {
+  return getHalideMul<HalideMulForward>(inFeatureCount, outFeatureCount,
+                                        groups, cuda, pimpl->forward);
+}
+
+const HalideMulBackward &
+HalideMulFactory::getHalideMulBackward(int inFeatureCount, int outFeatureCount,
+                                       int groups, bool cuda) const {
+  return getHalideMul<HalideMulBackward>(inFeatureCount, outFeatureCount,
+                                         groups, cuda, pimpl->backward);
+}
+
+HalideMul::HalideMul(int inFeatureCount, int outFeatureCount, int groups)
+    : dimensions({inFeatureCount, outFeatureCount, groups}) {}
+
+HalideMul::HalideMul(const LayerDimensions &dims) : dimensions(dims) {}
+
+HalideMul::~HalideMul() = default;
+
+/* Implementation of forward Halide matrix multiplication */
+struct HalideMulForward::Impl {
+public:
+  Impl(const LayerDimensions &dimensions, bool cuda) {
+    Halide::Target target = Halide::get_host_target();
+    Halide::Func matmul = Halide::Func("matmul");
+
+    /* Variables */
+    Halide::Var i, g, j;
+    Halide::RDom k{0, dimensions.inFeatureCount / dimensions.groups};
+
+    /* Algorithm */
+    Halide::Expr producer = clamp(rules(2 * i), 0, maxHalideRow - 1);
+    matmul(j, i, g) = sum(inputFeatures(k, g, producer) * weights(j, k, g));
+
+    /* Schedule */
+    matmul.estimate(j, 0, featureCount)
+        .estimate(g, 0, groups)
+        .estimate(i, 0, featureRowCount);
+
+    inputFeatures.dim(0).set_bounds_estimate(0, featureCount);
+    inputFeatures.dim(1).set_bounds_estimate(0, groups);
+    inputFeatures.dim(2).set_bounds_estimate(0, featureRowCount);
+
+    weights.dim(0).set_bounds_estimate(0, featureCount);
+    weights.dim(1).set_bounds_estimate(0, featureCount);
+    weights.dim(2).set_bounds_estimate(0, groups);
+
+    rules.dim(0).set_bounds_estimate(0, activeRows);
+    activeRowsParam.set_estimate(activeRows);
+
+    p = Halide::Pipeline({matmul});
+
+    if (!cuda) {
+      p.auto_schedule(target);
+    } else {
+      target.set_feature(Halide::Target::CUDA);
+    }
+
+    p.compile_jit(target);
+  };
+
+  Halide::ImageParam inputFeatures =
+      Halide::ImageParam(Halide::type_of<float>(), 3, "source");
+  Halide::ImageParam weights =
+      Halide::ImageParam(Halide::type_of<float>(), 3, "weight");
+  Halide::ImageParam rules =
+      Halide::ImageParam(Halide::type_of<int>(), 1, "rules");
+
+  Halide::Param<int> activeRowsParam = Halide::Param<int>("row_count");
Halide::Param("row_count"); + + Halide::Pipeline p; +}; + +HalideMulForward::HalideMulForward(int inFeatureCount, int outFeatureCount, + int groups, bool cuda) + : HalideMul(inFeatureCount, outFeatureCount, groups), + pimpl(new Impl(dimensions, cuda)) {} + +HalideMulForward::HalideMulForward(const LayerDimensions &dims) + : HalideMul(dims), pimpl(new Impl(dimensions, dims.cuda)) {} + +HalideMulForward::~HalideMulForward() = default; + +/* Executes the forward matrix multiplication created through the + implementation object. */ +void HalideMulForward::execute(float *input, float *weight, int *rules, + float *output, int activeRowCount) const { + + int inputPlanes = dimensions.inFeatureCount / dimensions.groups; + int outputPlanes = dimensions.outFeatureCount / dimensions.groups; + + pimpl->inputFeatures.set(Halide::Buffer( + input, inputPlanes, dimensions.groups, maxHalideRow)); + pimpl->weights.set(Halide::Buffer(weight, outputPlanes, inputPlanes, + dimensions.groups)); + pimpl->rules.set(Halide::Buffer(rules, 2 * activeRowCount)); + pimpl->activeRowsParam.set(activeRowCount); + + auto out = Halide::Buffer(output, outputPlanes, activeRowCount, + dimensions.groups); + pimpl->p.realize(out); +} + +/* Implementation of backward Halide matrix multiplication */ +struct HalideMulBackward::Impl { +public: + Impl(const LayerDimensions &dimensions, bool cuda) { + Halide::Target target = Halide::get_host_target(); + + int outputPlanes = dimensions.outFeatureCount / dimensions.groups; + + /* Variables */ + Halide::Func o_matmul = Halide::Func("o_matmul"); + Halide::Func o_weights = Halide::Func("o_weights"); + Halide::Var i, g, k, j, gw, outp, inp; + + Halide::RDom planes = Halide::RDom(0, outputPlanes); + Halide::RDom nums = Halide::RDom(0, activeRowsParam); + + /* Algorithm */ + Halide::Expr producer = clamp(rules(2 * i + 1), 0, maxHalideRow - 1); + + Halide::Expr orAccess_dom = clamp(rules(2 * nums + 1), 0, maxHalideRow - 1); + Halide::Expr irAccess_dom = clamp(rules(2 * nums), 0, maxHalideRow - 1); + + o_matmul(k, i, g) = + sum(weights(planes, k, g) * outputFeatures(planes, g, producer)); + + o_weights(outp, inp, gw) = sum(outputFeatures(outp, gw, orAccess_dom) * + inputFeatures(inp, gw, irAccess_dom)); + + /* Schedule */ + o_matmul.estimate(k, 0, featureCount) + .estimate(g, 0, groups) + .estimate(i, 0, featureRowCount); + o_weights.estimate(gw, 0, groups) + .estimate(outp, 0, featureCount) + .estimate(inp, 0, featureCount); + + inputFeatures.dim(0).set_bounds_estimate(0, featureCount); + inputFeatures.dim(1).set_bounds_estimate(0, groups); + inputFeatures.dim(2).set_bounds_estimate(0, featureRowCount); + + outputFeatures.dim(0).set_bounds_estimate(0, featureCount); + outputFeatures.dim(1).set_bounds_estimate(0, groups); + outputFeatures.dim(2).set_bounds_estimate(0, featureRowCount); + + weights.dim(0).set_bounds_estimate(0, featureCount); + weights.dim(1).set_bounds_estimate(0, featureCount); + weights.dim(2).set_bounds_estimate(0, groups); + + rules.dim(0).set_bounds_estimate(0, activeRows); + activeRowsParam.set_estimate(activeRows); + + p = Halide::Pipeline({o_matmul, o_weights}); + + if (cuda) { + target.set_feature(Halide::Target::CUDA); + } else { + p.auto_schedule(target); + } + + p.compile_jit(target); + }; + + Halide::ImageParam inputFeatures = + Halide::ImageParam(Halide::type_of(), 3, "input_features"); + Halide::ImageParam outputFeatures = + Halide::ImageParam(Halide::type_of(), 3, "output_features"); + Halide::ImageParam rules = + Halide::ImageParam(Halide::type_of(), 1, 
"rules"); + Halide::ImageParam weights = + Halide::ImageParam(Halide::type_of(), 3, "weights"); + + Halide::Param activeRowsParam = Halide::Param("row_count"); + + Halide::Pipeline p; +}; + +HalideMulBackward::HalideMulBackward(int inFeatureCount, int outFeatureCount, + int groups, bool cuda) + : HalideMul(inFeatureCount, outFeatureCount, groups), + pimpl(new Impl(dimensions, cuda)) {} + +HalideMulBackward::HalideMulBackward(const LayerDimensions &dims) + : HalideMul(dims), pimpl(new Impl(dimensions, dims.cuda)) {} + +HalideMulBackward::~HalideMulBackward() = default; + +/* Executes the backward matrix multiplications created through the + implementation object. */ +void HalideMulBackward::execute(float *inputFeatures, float *outputFeatures, + int *rules, float *weights, + float *dWeightsOutput, float *output, + int activeRowCount) const { + + int inputPlanes = dimensions.inFeatureCount / dimensions.groups; + int outputPlanes = dimensions.outFeatureCount / dimensions.groups; + + pimpl->inputFeatures.set(Halide::Buffer( + inputFeatures, inputPlanes, dimensions.groups, maxHalideRow)); + pimpl->outputFeatures.set(Halide::Buffer( + outputFeatures, outputPlanes, dimensions.groups, maxHalideRow)); + pimpl->weights.set(Halide::Buffer(weights, outputPlanes, inputPlanes, + dimensions.groups)); + pimpl->rules.set(Halide::Buffer(rules, 2 * activeRowCount)); + + pimpl->activeRowsParam.set(activeRowCount); + + auto halideOutput = Halide::Buffer(output, inputPlanes, activeRowCount, + dimensions.groups); + auto halideWOutput = Halide::Buffer(dWeightsOutput, outputPlanes, + inputPlanes, dimensions.groups); + + pimpl->p.realize({halideOutput, halideWOutput}); +} diff --git a/sparseconvnet/SCN/Halide/mul.hpp b/sparseconvnet/SCN/Halide/mul.hpp new file mode 100644 index 0000000..6adca1c --- /dev/null +++ b/sparseconvnet/SCN/Halide/mul.hpp @@ -0,0 +1,141 @@ +#ifndef MUL_H_ +#define MUL_H_ + +#include + +class HalideMul; +class HalideMulBackward; +class HalideMulForward; + +struct LayerDimensions { + int inFeatureCount; + int outFeatureCount; + int groups; + bool cuda; + + bool operator==(const LayerDimensions &that) const { + return inFeatureCount == that.inFeatureCount && + outFeatureCount == that.outFeatureCount && groups == that.groups && + cuda == that.cuda; + } +}; + +struct LayerDimensionsHash { + std::size_t operator()(const LayerDimensions &dims) const { + std::size_t seed = 16777619; + + combineHash(seed, dims.inFeatureCount); + combineHash(seed, dims.outFeatureCount); + combineHash(seed, dims.groups); + combineHash(seed, dims.cuda); + + return seed; + } + +private: + void combineHash(std::size_t &seed, int value) const { + seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } +}; + +/* Singleton for caching instances of Halide matrix multiplication. */ +class HalideMulFactory { +public: + ~HalideMulFactory(); + + static const HalideMulFactory &getInstance(); + + const HalideMulForward &getHalideMulForward(int inFeatureCount, + int outFeatureCount, int groups, + bool cuda) const; + + const HalideMulBackward &getHalideMulBackward(int inFeatureCount, + int outFeatureCount, int groups, + bool cuda) const; + +private: + HalideMulFactory(); + + HalideMulFactory(HalideMulFactory const &); + void operator=(HalideMulFactory const &); + + struct Impl; + const std::unique_ptr pimpl; +}; + +/* Sets up the dimensions of the layer. An instance needs to be created + once for every layer to set up the parameters of the halide + function based on the properties of the layer. 
diff --git a/sparseconvnet/SCN/Halide/mul.hpp b/sparseconvnet/SCN/Halide/mul.hpp
new file mode 100644
index 0000000..6adca1c
--- /dev/null
+++ b/sparseconvnet/SCN/Halide/mul.hpp
@@ -0,0 +1,141 @@
+#ifndef MUL_H_
+#define MUL_H_
+
+#include <memory>
+
+class HalideMul;
+class HalideMulBackward;
+class HalideMulForward;
+
+struct LayerDimensions {
+  int inFeatureCount;
+  int outFeatureCount;
+  int groups;
+  bool cuda;
+
+  bool operator==(const LayerDimensions &that) const {
+    return inFeatureCount == that.inFeatureCount &&
+           outFeatureCount == that.outFeatureCount && groups == that.groups &&
+           cuda == that.cuda;
+  }
+};
+
+struct LayerDimensionsHash {
+  std::size_t operator()(const LayerDimensions &dims) const {
+    std::size_t seed = 16777619;
+
+    combineHash(seed, dims.inFeatureCount);
+    combineHash(seed, dims.outFeatureCount);
+    combineHash(seed, dims.groups);
+    combineHash(seed, dims.cuda);
+
+    return seed;
+  }
+
+private:
+  void combineHash(std::size_t &seed, int value) const {
+    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+  }
+};
+
+/* Singleton for caching instances of Halide matrix multiplication. */
+class HalideMulFactory {
+public:
+  ~HalideMulFactory();
+
+  static const HalideMulFactory &getInstance();
+
+  const HalideMulForward &getHalideMulForward(int inFeatureCount,
+                                              int outFeatureCount, int groups,
+                                              bool cuda) const;
+
+  const HalideMulBackward &getHalideMulBackward(int inFeatureCount,
+                                                int outFeatureCount,
+                                                int groups, bool cuda) const;
+
+private:
+  HalideMulFactory();
+
+  HalideMulFactory(HalideMulFactory const &);
+  void operator=(HalideMulFactory const &);
+
+  struct Impl;
+  const std::unique_ptr<Impl> pimpl;
+};
+
+/* Holds the dimensions of a layer. An instance needs to be created once per
+   layer to set up the parameters of the Halide function from the layer's
+   properties. The Halide algorithm and schedule are set up at construction
+   in the child classes. */
+class HalideMul {
+public:
+  HalideMul(int inFeatureCount, int outFeatureCount, int groups);
+
+  HalideMul(const LayerDimensions &dims);
+
+  ~HalideMul();
+
+protected:
+  const LayerDimensions dimensions;
+};
+
+class HalideMulForward : public HalideMul {
+public:
+  HalideMulForward(int inFeatureCount, int outFeatureCount, int groups,
+                   bool cuda);
+
+  HalideMulForward(const LayerDimensions &dims);
+
+  ~HalideMulForward();
+
+  /* Executes forward matrix multiplication for a single filter offset (as
+     per the rulebook implementation). Due to Halide's column-major indexing,
+     the dimensions used are:
+
+     input  = input_planes x groups x input_row_count
+     weight = output_planes x input_planes x groups
+     rules  = 2 * active_row_count (in a single dimension)
+     output = output_planes x active_row_count x groups
+
+     To correctly write to the output feature matrix, use the
+     rule_index_add_() function with the obtained output. */
+  void execute(float *input, float *weight, int *rules, float *output,
+               int activeRowCount) const;
+
+private:
+  struct Impl;
+  const std::unique_ptr<Impl> pimpl;
+};
+
+class HalideMulBackward : public HalideMul {
+public:
+  HalideMulBackward(int inFeatureCount, int outFeatureCount, int groups,
+                    bool cuda);
+
+  HalideMulBackward(const LayerDimensions &dims);
+
+  ~HalideMulBackward();
+
+  /* Executes backward matrix multiplications for a single filter offset (as
+     per the rulebook implementation). Due to Halide's column-major indexing,
+     the dimensions used are:
+
+     inputFeatures     = input_planes x groups x input_row_count
+     outputFeatures    = output_planes x groups x output_row_count
+     rules             = 2 * active_row_count (in a single dimension)
+     weights           = output_planes x input_planes x groups
+     d_weights_output  = output_planes x input_planes x groups
+     input_rows_output = input_planes x active_row_count x groups
+
+     To correctly write to the input feature matrix, use the
+     rule_index_add_() function with the obtained input_rows_output. */
+  void execute(float *inputFeatures, float *outputFeatures, int *rules,
+               float *weights, float *dWeightsOutput, float *output,
+               int activeRowCount) const;
+
+private:
+  struct Impl;
+  const std::unique_ptr<Impl> pimpl;
+};
+
+#endif // !MUL_H_
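Finally, a small illustrative check (not part of the patch) that LayerDimensions works as an unordered_map key under LayerDimensionsHash, which is how MulStrategyMap in mul.cpp caches compiled pipelines. The cuda flag participates in both hashing and equality, so CPU and CUDA pipelines for the same shape are cached as distinct entries.

#include "mul.hpp"

#include <cassert>
#include <unordered_map>

int main() {
  std::unordered_map<LayerDimensions, int, LayerDimensionsHash> cache;
  LayerDimensions cpuDims{32, 64, 1, false};
  LayerDimensions cudaDims{32, 64, 1, true};
  cache[cpuDims] = 1;
  cache[cudaDims] = 2; // differs from cpuDims only in the cuda flag
  assert(cache.size() == 2);
  assert(cache[cpuDims] == 1);
  return 0;
}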