// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MAX_COST_ASSIgNMENT_Hh_
#define DLIB_MAX_COST_ASSIgNMENT_Hh_
#include "max_cost_assignment_abstract.h"
#include "../matrix.h"
#include <vector>
#include <deque>
namespace dlib
{
// ----------------------------------------------------------------------------------------
/*!
    Returns the total cost of an assignment, i.e. the sum over all valid i of
    cost(i, assignment[i]).  Entry i of assignment is used as the column index
    paired with row i of the cost matrix.

    The preconditions below are verified through DLIB_ASSERT:
        - cost.nr() == cost.nc()
        - every element of assignment lies in the range [0, cost.nr())
          (this check is compiled in only when ENABLE_ASSERTS is defined)
!*/
template <typename EXP>
typename EXP::type assignment_cost (
    const matrix_exp<EXP>& cost,
    const std::vector<long>& assignment
)
{
    DLIB_ASSERT(cost.nr() == cost.nc(),
        "\t type assignment_cost(cost,assignment)"
        << "\n\t cost.nr(): " << cost.nr()
        << "\n\t cost.nc(): " << cost.nc()
        );

#ifdef ENABLE_ASSERTS
    // min() and max() can't be applied to an empty vector, so the range check
    // is only performed when there is at least one entry to inspect.
    if (!assignment.empty())
    {
        DLIB_ASSERT(0 <= min(mat(assignment)) && max(mat(assignment)) < cost.nr(),
            "\t type assignment_cost(cost,assignment)"
            << "\n\t cost.nr(): " << cost.nr()
            << "\n\t cost.nc(): " << cost.nc()
            << "\n\t min(assignment): " << min(mat(assignment))
            << "\n\t max(assignment): " << max(mat(assignment))
            );
    }
#endif

    // Walk the assignment in order, keeping a row counter in step with the
    // iterator so that each element is looked up as cost(row, column).
    typename EXP::type total = 0;
    unsigned long row = 0;
    for (std::vector<long>::const_iterator col = assignment.begin();
         col != assignment.end();
         ++col, ++row)
    {
        total += cost(row, *col);
    }
    return total;
}
// ----------------------------------------------------------------------------------------
namespace impl
{
template <typename EXP>
inline void compute_slack(
const long x,
std::vector<typename EXP::type>& slack,
std::vector<long>& slackx,
const matrix_exp<EXP>& cost,
const std::vector<typename EXP::type>& lx,
const std::vector<typename EXP::type>& ly
)
{
for (long y = 0; y < cost.nc(); ++y)
{
if (lx[x] + ly[y] - cost(x,y) < slack[y])
{
slack[y] = lx[x] + ly[y] - cost(x,y);
slackx[y] = x;
}
}
}
}
// ----------------------------------------------------------------------------------------
/*!
Finds a complete matching between the rows and columns of the square, integer
valued matrix cost_ using the O(n^3) Hungarian (Kuhn-Munkres) algorithm, growing
the matching by one edge per outer iteration until every row is matched.

requires
- cost_.nr() == cost_.nc() (verified through DLIB_ASSERT)
- EXP::type is an integer type (verified at compile time)

returns
- A vector with one entry per row: element x holds the column matched to row x.
- An empty vector if cost_ contains no elements.

NOTE(review): all label and slack arithmetic (e.g. lx[x] + ly[y] - cost(x,y)) is
carried out in EXP::type with no overflow checks -- presumably callers must keep
the cost magnitudes small enough for that type; confirm.
!*/
template <typename EXP>
std::vector<long> max_cost_assignment (
const matrix_exp<EXP>& cost_
)
{
// NOTE(review): const_temp_matrix presumably materializes the expression once so
// the many cost(x,y) reads below are cheap -- confirm against its definition.
const_temp_matrix<EXP> cost(cost_);
typedef typename EXP::type type;
// This algorithm only works if the elements of the cost matrix can be reliably
// compared using operator==. However, comparing for equality with floating point
// numbers is not a stable operation. So you need to use an integer cost matrix.
COMPILE_TIME_ASSERT(std::numeric_limits<type>::is_integer);
DLIB_ASSERT(cost.nr() == cost.nc(),
"\t std::vector<long> max_cost_assignment(cost)"
<< "\n\t cost.nr(): " << cost.nr()
<< "\n\t cost.nc(): " << cost.nc()
);
// Makes compute_slack() callable without qualification.
using namespace dlib::impl;
/*
I based the implementation of this algorithm on the description of the
Hungarian algorithm on the following websites:
http://www.math.uwo.ca/~mdawes/courses/344/kuhn-munkres.pdf
http://www.topcoder.com/tc?module=Static&d1=tutorials&d2=hungarianAlgorithm
Note that this is the fast O(n^3) version of the algorithm.
*/
// Nothing to match in an empty matrix.
if (cost.size() == 0)
return std::vector<long>();
// lx and ly hold the labels of the row (x) and column (y) vertices.
std::vector<type> lx, ly;
// xy[x] is the column matched to row x; yx[y] is the row matched to column y.
std::vector<long> xy;
std::vector<long> yx;
// Membership flags for the row set S and the column set T of the alternating
// tree built while searching for an augmenting path.
std::vector<char> S, T;
// Per-column slack values and the row that produced each of them.
std::vector<type> slack;
std::vector<long> slackx;
// aug_path[x] is the row from which row x was reached during the search, or -1
// if x has no predecessor (it is the root of the search).
std::vector<long> aug_path;
// Initially, nothing is matched.
xy.assign(cost.nc(), -1);
yx.assign(cost.nc(), -1);
/*
We maintain the following invariant:
Vertex x is matched to vertex xy[x] and
vertex y is matched to vertex yx[y].
A value of -1 means a vertex isn't matched to anything. Moreover,
x corresponds to rows of the cost matrix and y corresponds to the
columns of the cost matrix. So we are matching X to Y.
*/
// Create an initial feasible labeling. Moreover, in the following
// code we will always have:
// for all valid x and y: lx[x] + ly[y] >= cost(x,y)
// Each row label starts at the largest cost in its row and every column label
// starts at 0, which satisfies the inequality above.
lx.resize(cost.nc());
ly.assign(cost.nc(),0);
for (long x = 0; x < cost.nr(); ++x)
lx[x] = max(rowm(cost,x));
// Now grow the match set by picking edges from the equality subgraph until
// we have a complete matching. Each pass of this loop adds exactly one edge.
for (long match_size = 0; match_size < cost.nc(); ++match_size)
{
// Queue of row vertices whose outgoing edges still need to be explored.
std::deque<long> q;
// Empty out the S and T sets
S.assign(cost.nc(), false);
T.assign(cost.nc(), false);
// clear out old slack values
slack.assign(cost.nc(), std::numeric_limits<type>::max());
slackx.resize(cost.nc());
/*
slack and slackx are maintained such that we always
have the following (once they get initialized by compute_slack() below):
- for all y:
- let x == slackx[y]
- slack[y] == lx[x] + ly[y] - cost(x,y)
*/
aug_path.assign(cost.nc(), -1);
// Pick the first unmatched row as the root of this search.
for (long x = 0; x < cost.nc(); ++x)
{
// If x is not matched to anything
if (xy[x] == -1)
{
q.push_back(x);
S[x] = true;
compute_slack(x, slack, slackx, cost, lx, ly);
break;
}
}
// End points of the final (unmatched) edge of the augmenting path.
long x_start = 0;
long y_start = 0;
// Find an augmenting path.
bool found_augmenting_path = false;
while (!found_augmenting_path)
{
// Breadth-first exploration using only edges of the equality subgraph, i.e.
// edges with cost(x,y) == lx[x] + ly[y].
while (q.size() > 0 && !found_augmenting_path)
{
const long x = q.front();
q.pop_front();
for (long y = 0; y < cost.nc(); ++y)
{
if (cost(x,y) == lx[x] + ly[y] && !T[y])
{
// if vertex y isn't matched with anything
if (yx[y] == -1)
{
y_start = y;
x_start = x;
found_augmenting_path = true;
break;
}
// y is already matched: add y to T and continue the search from the row it
// is matched to, remembering that this row was reached from x.
T[y] = true;
q.push_back(yx[y]);
aug_path[yx[y]] = x;
S[yx[y]] = true;
compute_slack(yx[y], slack, slackx, cost, lx, ly);
}
}
}
if (found_augmenting_path)
break;
// Since we didn't find an augmenting path we need to improve the
// feasible labeling stored in lx and ly. We also need to keep the
// slack updated accordingly.
// delta is the smallest slack among the columns that are not in T.
type delta = std::numeric_limits<type>::max();
for (unsigned long i = 0; i < T.size(); ++i)
{
if (!T[i])
delta = std::min(delta, slack[i]);
}
// Lower the labels of rows in S, raise the labels of columns in T, and reduce
// the slack of every column outside T by the same amount. The column(s) that
// attained the minimum above end up with a slack of exactly 0.
for (unsigned long i = 0; i < T.size(); ++i)
{
if (S[i])
lx[i] -= delta;
if (T[i])
ly[i] += delta;
else
slack[i] -= delta;
}
// Restart the queue with the rows reachable through the columns whose slack
// just dropped to 0.
q.clear();
for (long y = 0; y < cost.nc(); ++y)
{
if (!T[y] && slack[y] == 0)
{
// if vertex y isn't matched with anything
if (yx[y] == -1)
{
x_start = slackx[y];
y_start = y;
found_augmenting_path = true;
break;
}
else
{
T[y] = true;
// Only enqueue the matched row if it hasn't been visited already.
if (!S[yx[y]])
{
q.push_back(yx[y]);
aug_path[yx[y]] = slackx[y];
S[yx[y]] = true;
compute_slack(yx[y], slack, slackx, cost, lx, ly);
}
}
}
}
} // end while (!found_augmenting_path)
// Flip the edges along the augmenting path. This means we will add one more
// item to our matching.
// Starting from (x_start, y_start), follow aug_path back toward the root. At
// each step ty saves the column cx was previously matched to (before it gets
// overwritten) so that column can be handed to cx's predecessor next.
for (long cx = x_start, cy = y_start, ty;
cx != -1;
cx = aug_path[cx], cy = ty)
{
ty = xy[cx];
yx[cy] = cx;
xy[cx] = cy;
}
}
return xy;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_MAX_COST_ASSIgNMENT_Hh_