/* * BMatchingSolver.cpp * * Created on: Dec 8, 2010 * Author: bert */ #include #include #include "BMatchingLibrary.h" using namespace std; SparseMatrix *bmatchingLibrary::bMatchBipartiteMatrixSparse(int rows, int cols, double **W, int *brows, int *bcols) { return bMatchBipartiteMatrixSparse(rows, cols, W, brows, bcols, DEFAULT_MAX_ITERS); } SparseMatrix *bmatchingLibrary::bMatchBipartiteMatrixSparse(int rows, int cols, double **W, int *brows, int *bcols, int maxIters) { return bMatchBipartiteMatrixCache(rows, cols, W, brows, bcols, rows + cols, maxIters); } SparseMatrix *bmatchingLibrary::bMatchBipartiteMatrixCache(int rows, int cols, double **W, int *brows, int *bcols, int cacheSize, int maxIters, bool verbose) { int * b = cat(brows, bcols, rows, cols); WeightOracle * wo = new BipartiteMatrixOracle(W, rows, cols); wo->setCacheSize(cacheSize); SparseMatrix * ret = solveBMatchingSparse(wo, b, maxIters, verbose); delete[] (b); delete (wo); return ret; } bool **bmatchingLibrary::bMatchBipartiteEuclidean(int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols) { return bMatchBipartiteEuclidean(rows, cols, d, X, Y, brows, bcols, DEFAULT_MAX_ITERS); } SparseMatrix *bmatchingLibrary::bMatchBipartiteEuclideanSparse(int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols) { return bMatchBipartiteEuclideanSparse(rows, cols, d, X, Y, brows, bcols, DEFAULT_MAX_ITERS); } SparseMatrix *bmatchingLibrary::bMatchBipartiteEuclideanSparse(int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols, int maxIters) { return bMatchBipartiteEuclideanCache(rows, cols, d, X, Y, brows, bcols, rows + cols, maxIters); } SparseMatrix *bmatchingLibrary::bMatchBipartiteEuclideanCache(int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols, int cacheSize, int maxIters, bool verbose) { int * b = cat(brows, bcols, rows, cols); WeightFunction * wf = new EuclideanDistance(); WeightOracle * wo = new BipartiteFunctionOracle(X, Y, wf, rows, cols, d); wo->setCacheSize(cacheSize); SparseMatrix * ret = solveBMatchingSparse(wo, b, maxIters, verbose); delete[] (b); delete(wf); delete(wo); return ret; } SparseMatrix *bmatchingLibrary::bMatchBipartiteInnerProductSparse( int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols) { return bMatchBipartiteInnerProductSparse(rows, cols, d, X, Y, brows, bcols, DEFAULT_MAX_ITERS); } SparseMatrix *bmatchingLibrary::bMatchBipartiteInnerProductSparse( int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols, int maxIters) { return bMatchBipartiteInnerProductCache(rows, cols, d, X, Y, brows, bcols, rows + cols, maxIters); } SparseMatrix *bmatchingLibrary::bMatchBipartiteInnerProductCache( int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols, int cacheSize, int maxIters, bool verbose) { int * b = cat(brows, bcols, rows, cols); WeightFunction * wf = new InnerProduct(); WeightOracle * wo = new BipartiteFunctionOracle(X, Y, wf, rows, cols, d); wo->setCacheSize(cacheSize); SparseMatrix * ret = solveBMatchingSparse(wo, b, maxIters, verbose); delete[] (b); delete(wf); delete(wo); return ret; } SparseMatrix *bmatchingLibrary::bMatchMatrixSparse(int size, double **W, int *b) { return bMatchMatrixSparse(size, W, b, DEFAULT_MAX_ITERS); } SparseMatrix *bmatchingLibrary::bMatchMatrixSparse(int size, double **W, int *b, int maxIters) { return bMatchMatrixCache(size, W, b, size, maxIters); } SparseMatrix *bmatchingLibrary::bMatchMatrixCache(int size, double **W, int *b, int cacheSize, int maxIters, bool verbose) { WeightOracle * wo = new MatrixOracle(size, W); wo->setCacheSize(cacheSize); SparseMatrix * ret = solveBMatchingSparse(wo, b, maxIters, verbose); delete (wo); return ret; } SparseMatrix *bmatchingLibrary::bMatchEuclideanSparse(int size, int d, double **X, int *b) { return bMatchEuclideanSparse(size, d, X, b, DEFAULT_MAX_ITERS); } SparseMatrix *bmatchingLibrary::bMatchEuclideanSparse(int size, int d, double **X, int *b, int maxIters) { return bMatchEuclideanCache(size, d, X, b, size, maxIters); } SparseMatrix *bmatchingLibrary::bMatchEuclideanCache(int size, int d, double **X, int *b, int cacheSize, int maxIters, bool verbose) { WeightFunction * wf = new EuclideanDistance(); WeightOracle * wo = new FunctionOracle(X, wf, size, d); wo->setCacheSize(cacheSize); SparseMatrix * ret = solveBMatchingSparse(wo, b, maxIters, verbose); delete (wf); delete (wo); return ret; } SparseMatrix *bmatchingLibrary::bMatchInnerProductSparse(int size, int d, double **X, int *b) { return bMatchInnerProductSparse(size, d, X, b, DEFAULT_MAX_ITERS); } SparseMatrix *bmatchingLibrary::bMatchInnerProductSparse(int size, int d, double **X, int *b, int maxIters) { return bMatchInnerProductCache(size, d, X, b, size, maxIters); } SparseMatrix *bmatchingLibrary::bMatchInnerProductCache(int size, int d, double **X, int *b, int cacheSize, int maxIters, bool verbose) { WeightFunction * wf = new InnerProduct(); WeightOracle * wo = new FunctionOracle(X, wf, size, d); wo->setCacheSize(cacheSize); SparseMatrix * ret = solveBMatchingSparse(wo, b, maxIters, verbose); delete (wf); delete (wo); return ret; } // nonsparse versions bool **bmatchingLibrary::bMatchBipartiteInnerProduct(int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols) { return bMatchBipartiteInnerProduct(rows, cols, d, X, Y, brows, bcols, DEFAULT_MAX_ITERS); } bool **bmatchingLibrary::bMatchMatrix(int size, double **W, int *b, int maxIters) { SparseMatrix * solution = bMatchMatrixSparse(size, W, b, maxIters); bool ** ret = convertToBool(solution); delete (solution); return ret; } bool **bmatchingLibrary::bMatchEuclidean(int size, int d, double **X, int *b) { return bMatchEuclidean(size, d, X, b, DEFAULT_MAX_ITERS); } bool **bmatchingLibrary::bMatchInnerProduct(int size, int d, double **X, int *b) { return bMatchInnerProduct(size, d, X, b, DEFAULT_MAX_ITERS); } bool **bmatchingLibrary::bMatchBipartiteMatrix(int rows, int cols, double **W, int *brows, int *bcols) { return bMatchBipartiteMatrix(rows, cols, W, brows, bcols, DEFAULT_MAX_ITERS); } bool **bmatchingLibrary::bMatchEuclidean(int size, int d, double **X, int *b, int maxIters) { SparseMatrix * solution = bMatchEuclideanSparse(size, d, X, b, maxIters); bool ** ret = convertToBool(solution); delete (solution); return ret; } bool **bmatchingLibrary::bMatchBipartiteMatrix(int rows, int cols, double **W, int *brows, int *bcols, int maxIters) { SparseMatrix * solution = bMatchBipartiteMatrixSparse(rows, cols, W, brows, bcols, maxIters); bool ** ret = convertToBool(solution); delete (solution); return ret; } bool **bmatchingLibrary::bMatchBipartiteEuclidean(int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols, int maxIters) { SparseMatrix * solution = bMatchBipartiteEuclideanSparse(rows, cols, d, X, Y, brows, bcols, maxIters); bool ** ret = convertToBool(solution); delete (solution); return ret; } bool **bmatchingLibrary::bMatchMatrix(int size, double **W, int *b) { return bMatchMatrix(size, W, b, DEFAULT_MAX_ITERS); } bool **bmatchingLibrary::bMatchBipartiteInnerProduct(int rows, int cols, int d, double **X, double **Y, int *brows, int *bcols, int maxIters) { SparseMatrix * solution = bMatchBipartiteInnerProductSparse(rows, cols, d, X, Y, brows, bcols, maxIters); bool ** ret = convertToBool(solution); delete (solution); return ret; } bool **bmatchingLibrary::bMatchInnerProduct(int size, int d, double **X, int *b, int maxIters) { SparseMatrix * solution = bMatchInnerProductSparse(size, d, X, b, maxIters); bool ** ret = convertToBool(solution); delete (solution); return ret; } bool **bmatchingLibrary::convertToBool(SparseMatrix * solution) { int rows = solution->getRowCount(); int cols = solution->getColCount(); bool ** ret = new bool*[rows]; for (int i = 0; i < rows; i++) { ret[i] = new bool[cols]; std::fill(ret[i], ret[i] + cols, false); } int * nzrows = solution->getRows(); int * nzcols = solution->getCols(); for (int i = 0; i < solution->getNNz(); i++) { ret[nzrows[i]][nzcols[i]] = true; } return ret; } SparseMatrix *bmatchingLibrary::solveBMatchingSparse(WeightOracle *wo, int *b, int maxIters, bool verbose) { time_t start, end; time(&start); wo->computeIndex(); time(&end); if (verbose) { cout << "Computing cache took " << difftime(end, start) << " s. Starting belief propagation..." << endl; } BMatchingProblem * bmatching = new BMatchingProblem(wo, wo->getSize(), b, verbose); bmatching->setMaxIter(maxIters); SparseMatrix * solution; solution = bmatching->solve(); delete (bmatching); if (verbose) { wo->printStatsString(); } return solution; } int *bmatchingLibrary::cat(int *a, int *b, int aSize, int bSize) { int * ret = new int[aSize + bSize]; for (int i = 0; i < aSize; i++) ret[i] = a[i]; for (int i = 0; i < bSize; i++) ret[aSize + i] = b[i]; return ret; }