// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_
#define DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_
#include "structural_assignment_trainer_abstract.h"
#include "../algs.h"
#include "../optimization.h"
#include "structural_svm_assignment_problem.h"
#include "num_nonnegative_weights.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
// Trainer that produces an assignment_function<feature_extractor>.  It stores a
// handful of tunable settings (see apply_default_settings() for their initial
// values), and train() forwards them to a structural_svm_assignment_problem
// which is then optimized with the stored oca solver.
template <
    typename feature_extractor
    >
class structural_assignment_trainer
{
public:

    typedef typename feature_extractor::lhs_element lhs_element;
    typedef typename feature_extractor::rhs_element rhs_element;
    // A sample is a pair of (left hand side elements, right hand side elements).
    typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
    typedef std::vector<long> label_type;
    typedef assignment_function<feature_extractor> trained_function_type;

    // Builds a trainer with a default constructed feature extractor and the
    // default settings.
    structural_assignment_trainer ()
    {
        apply_default_settings();
    }

    // Builds a trainer that keeps its own copy of fe_ and uses the default
    // settings.
    explicit structural_assignment_trainer (const feature_extractor& fe_)
        : fe(fe_)
    {
        apply_default_settings();
    }

    // Read-only access to the feature extractor stored in this trainer.
    const feature_extractor& get_feature_extractor () const
    {
        return fe;
    }

    // --- number of threads (handed to the problem object inside train()) ---

    void set_num_threads (unsigned long num)
    {
        num_threads = num;
    }

    unsigned long get_num_threads () const
    {
        return num_threads;
    }

    // --- epsilon (handed to the problem object via set_epsilon()) ---

    void set_epsilon (double eps_)
    {
        // Precondition: eps_ has to be strictly positive.
        DLIB_ASSERT(eps_ > 0,
            "\t void structural_assignment_trainer::set_epsilon()"
            << "\n\t eps_ must be greater than 0"
            << "\n\t eps_: " << eps_
            << "\n\t this: " << this
            );

        eps = eps_;
    }

    double get_epsilon () const
    {
        return eps;
    }

    // --- cache size (handed to the problem object via set_max_cache_size()) ---

    void set_max_cache_size (unsigned long max_size)
    {
        max_cache_size = max_size;
    }

    unsigned long get_max_cache_size () const
    {
        return max_cache_size;
    }

    // --- verbosity: when enabled, train() calls be_verbose() on the problem ---

    void be_verbose ()
    {
        verbose = true;
    }

    void be_quiet ()
    {
        verbose = false;
    }

    // --- the oca solver object invoked by train() ---

    void set_oca (const oca& item)
    {
        solver = item;
    }

    // Returns a copy of the stored solver.
    const oca get_oca () const
    {
        return solver;
    }

    // --- C parameter (handed to the problem object via set_c()) ---

    void set_c (double C_)
    {
        // Precondition: C_ has to be strictly positive.
        DLIB_ASSERT(C_ > 0,
            "\t void structural_assignment_trainer::set_c()"
            << "\n\t C_ must be greater than 0"
            << "\n\t C_: " << C_
            << "\n\t this: " << this
            );

        C = C_;
    }

    double get_c () const
    {
        return C;
    }

    // --- forced assignment flag ---
    // The flag selects which input check train() performs, and is passed on to
    // both the problem object and the returned assignment_function.

    bool forces_assignment () const
    {
        return force_assignment;
    }

    void set_forces_assignment (bool new_value)
    {
        force_assignment = new_value;
    }

    // --- loss values (both handed to the problem object's constructor) ---

    void set_loss_per_false_association (double loss)
    {
        // Precondition: loss has to be strictly positive.
        DLIB_ASSERT(loss > 0,
            "\t void structural_assignment_trainer::set_loss_per_false_association(loss)"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t loss: " << loss
            << "\n\t this: " << this
            );

        loss_per_false_association = loss;
    }

    double get_loss_per_false_association () const
    {
        return loss_per_false_association;
    }

    void set_loss_per_missed_association (double loss)
    {
        // Precondition: loss has to be strictly positive.
        DLIB_ASSERT(loss > 0,
            "\t void structural_assignment_trainer::set_loss_per_missed_association(loss)"
            << "\n\t Invalid inputs were given to this function "
            << "\n\t loss: " << loss
            << "\n\t this: " << this
            );

        loss_per_missed_association = loss;
    }

    double get_loss_per_missed_association () const
    {
        return loss_per_missed_association;
    }

    // --- "last weight is 1" flag ---
    // When set, train() passes the index fe.num_features()-1 to the solver as
    // an extra argument.

    bool forces_last_weight_to_1 () const
    {
        return last_weight_1;
    }

    void force_last_weight_to_1 (bool should_last_weight_be_1)
    {
        last_weight_1 = should_last_weight_be_1;
    }

    // Learns an assignment function from the given samples and labels.
    //
    // Steps:
    //   1. (asserts enabled only) validate the inputs according to the
    //      force_assignment setting.
    //   2. construct a structural_svm_assignment_problem and copy this
    //      trainer's settings into it.
    //   3. run the stored oca solver to obtain the weight vector.
    //   4. return an assignment_function built from all weights except the
    //      last one, with the last weight supplied as the bias.
    const assignment_function<feature_extractor> train (
        const std::vector<sample_type>& samples,
        const std::vector<label_type>& labels
    ) const
    {
        // Check the preconditions.  Which predicate has to hold depends on
        // whether assignments are forced.
#ifdef ENABLE_ASSERTS
        if (force_assignment)
        {
            DLIB_ASSERT(is_forced_assignment_problem(samples, labels),
                "\t assignment_function structural_assignment_trainer::train()"
                << "\n\t invalid inputs were given to this function"
                << "\n\t is_forced_assignment_problem(samples,labels): " << is_forced_assignment_problem(samples,labels)
                << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
                << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
                );
        }
        else
        {
            DLIB_ASSERT(is_assignment_problem(samples, labels),
                "\t assignment_function structural_assignment_trainer::train()"
                << "\n\t invalid inputs were given to this function"
                << "\n\t is_assignment_problem(samples,labels): " << is_assignment_problem(samples,labels)
                << "\n\t is_learning_problem(samples,labels): " << is_learning_problem(samples,labels)
                );
        }
#endif

        // Set up the optimization problem and mirror our settings onto it.
        structural_svm_assignment_problem<feature_extractor> prob(
            samples,
            labels,
            fe,
            force_assignment,
            num_threads,
            loss_per_false_association,
            loss_per_missed_association);

        if (verbose)
        {
            prob.be_verbose();
        }

        prob.set_c(C);
        prob.set_epsilon(eps);
        prob.set_max_cache_size(max_cache_size);

        matrix<double,0,1> weights;

        // The number of non-negative weights is capped at fe.num_features() so
        // that a user can never, by accident, force the bias term to be
        // non-negative.
        const unsigned long num_nonneg = std::min(fe.num_features(),num_nonnegative_weights(fe));

        if (!last_weight_1)
        {
            solver(prob, weights, num_nonneg);
        }
        else
        {
            solver(prob, weights, num_nonneg, fe.num_features()-1);
        }

        // The final entry of the solution is the bias; everything in front of
        // it is the weight vector given to the assignment function.
        const double bias = weights(weights.size()-1);
        return trained_function_type(
            colm(weights,0,weights.size()-1),
            bias,
            fe,
            force_assignment);
    }

private:

    // Settings.  Each of them is given its initial value by
    // apply_default_settings(), except solver which is default constructed.
    bool force_assignment;
    double C;
    oca solver;
    double eps;
    bool verbose;
    unsigned long num_threads;
    unsigned long max_cache_size;
    double loss_per_false_association;
    double loss_per_missed_association;
    bool last_weight_1;

    // Puts every plain setting into its initial state.  Invoked by both
    // constructors.
    void apply_default_settings ()
    {
        // solver parameters
        C = 100;
        eps = 0.01;
        max_cache_size = 5;
        num_threads = 2;

        // loss values
        loss_per_false_association = 1;
        loss_per_missed_association = 1;

        // flags
        force_assignment = false;
        last_weight_1 = false;
        verbose = false;
    }

    feature_extractor fe;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_STRUCTURAL_ASSiGNMENT_TRAINER_Hh_