// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_KALMAN_FiLTER_Hh_
#define DLIB_KALMAN_FiLTER_Hh_
#include "kalman_filter_abstract.h"
#include "../matrix.h"
#include "../geometry.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
long states,
long measurements
>
class kalman_filter
{
public:
kalman_filter()
{
H = 0;
A = 0;
Q = 0;
R = 0;
x = 0;
xb = 0;
P = identity_matrix<double>(states);
got_first_meas = false;
}
void set_observation_model ( const matrix<double,measurements,states>& H_) { H = H_; }
void set_transition_model ( const matrix<double,states,states>& A_) { A = A_; }
void set_process_noise ( const matrix<double,states,states>& Q_) { Q = Q_; }
void set_measurement_noise ( const matrix<double,measurements,measurements>& R_) { R = R_; }
void set_estimation_error_covariance( const matrix<double,states,states>& P_) { P = P_; }
void set_state ( const matrix<double,states,1>& xb_)
{
xb = xb_;
if (!got_first_meas)
{
x = xb_;
got_first_meas = true;
}
}
const matrix<double,measurements,states>& get_observation_model (
) const { return H; }
const matrix<double,states,states>& get_transition_model (
) const { return A; }
const matrix<double,states,states>& get_process_noise (
) const { return Q; }
const matrix<double,measurements,measurements>& get_measurement_noise (
) const { return R; }
void update (
)
{
// propagate estimation error covariance forward
P = A*P*trans(A) + Q;
// propagate state forward
x = xb;
xb = A*x;
}
void update (const matrix<double,measurements,1>& z)
{
// propagate estimation error covariance forward
P = A*P*trans(A) + Q;
// compute Kalman gain matrix
const matrix<double,states,measurements> K = P*trans(H)*pinv(H*P*trans(H) + R);
if (got_first_meas)
{
const matrix<double,measurements,1> res = z - H*xb;
// correct the current state estimate
x = xb + K*res;
}
else
{
// Since we don't have a previous state estimate at the start of filtering,
// we will just set the current state to whatever is indicated by the measurement
x = pinv(H)*z;
got_first_meas = true;
}
// propagate state forward in time
xb = A*x;
// update estimation error covariance since we got a measurement.
P = (identity_matrix<double,states>() - K*H)*P;
}
const matrix<double,states,1>& get_current_state(
) const
{
return x;
}
const matrix<double,states,1>& get_predicted_next_state(
) const
{
return xb;
}
const matrix<double,states,states>& get_current_estimation_error_covariance(
) const
{
return P;
}
friend inline void serialize(const kalman_filter& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.got_first_meas, out);
serialize(item.x, out);
serialize(item.xb, out);
serialize(item.P, out);
serialize(item.H, out);
serialize(item.A, out);
serialize(item.Q, out);
serialize(item.R, out);
}
friend inline void deserialize(kalman_filter& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw dlib::serialization_error("Unknown version number found while deserializing kalman_filter object.");
deserialize(item.got_first_meas, in);
deserialize(item.x, in);
deserialize(item.xb, in);
deserialize(item.P, in);
deserialize(item.H, in);
deserialize(item.A, in);
deserialize(item.Q, in);
deserialize(item.R, in);
}
private:
bool got_first_meas;
matrix<double,states,1> x, xb;
matrix<double,states,states> P;
matrix<double,measurements,states> H;
matrix<double,states,states> A;
matrix<double,states,states> Q;
matrix<double,measurements,measurements> R;
};
// ----------------------------------------------------------------------------------------
class momentum_filter
{
public:
momentum_filter(
double meas_noise,
double acc,
double max_meas_dev
) :
measurement_noise(meas_noise),
typical_acceleration(acc),
max_measurement_deviation(max_meas_dev)
{
DLIB_CASSERT(meas_noise >= 0);
DLIB_CASSERT(acc >= 0);
DLIB_CASSERT(max_meas_dev >= 0);
kal.set_observation_model({1, 0});
kal.set_transition_model( {1, 1,
0, 1});
kal.set_process_noise({0, 0,
0, typical_acceleration*typical_acceleration});
kal.set_measurement_noise({measurement_noise*measurement_noise});
}
momentum_filter() = default;
double get_measurement_noise (
) const { return measurement_noise; }
double get_typical_acceleration (
) const { return typical_acceleration; }
double get_max_measurement_deviation (
) const { return max_measurement_deviation; }
void reset()
{
*this = momentum_filter(measurement_noise, typical_acceleration, max_measurement_deviation);
}
double get_predicted_next_position(
) const
{
return kal.get_predicted_next_state()(0);
}
double operator()(
const double measured_position
)
{
auto x = kal.get_predicted_next_state();
const auto max_deviation = max_measurement_deviation*measurement_noise;
// Check if measured_position has suddenly jumped in value by a whole lot. This
// could happen if the velocity term experiences a much larger than normal
// acceleration, e.g. because the underlying object is doing a maneuver. If
// this happens then we clamp the state so that the predicted next value is no
// more than max_deviation away from measured_position at all times.
if (x(0) > measured_position + max_deviation)
{
x(0) = measured_position + max_deviation;
kal.set_state(x);
}
else if (x(0) < measured_position - max_deviation)
{
x(0) = measured_position - max_deviation;
kal.set_state(x);
}
kal.update({measured_position});
return kal.get_current_state()(0);
}
friend std::ostream& operator << (std::ostream& out, const momentum_filter& item)
{
out << "measurement_noise: " << item.measurement_noise << "\n";
out << "typical_acceleration: " << item.typical_acceleration << "\n";
out << "max_measurement_deviation: " << item.max_measurement_deviation;
return out;
}
friend void serialize(const momentum_filter& item, std::ostream& out)
{
int version = 15;
serialize(version, out);
serialize(item.measurement_noise, out);
serialize(item.typical_acceleration, out);
serialize(item.max_measurement_deviation, out);
serialize(item.kal, out);
}
friend void deserialize(momentum_filter& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 15)
throw serialization_error("Unexpected version found while deserializing momentum_filter.");
deserialize(item.measurement_noise, in);
deserialize(item.typical_acceleration, in);
deserialize(item.max_measurement_deviation, in);
deserialize(item.kal, in);
}
private:
double measurement_noise = 2;
double typical_acceleration = 0.1;
double max_measurement_deviation = 3; // nominally number of standard deviations
kalman_filter<2,1> kal;
};
// ----------------------------------------------------------------------------------------
momentum_filter find_optimal_momentum_filter (
const std::vector<std::vector<double>>& sequences,
const double smoothness = 1
);
// ----------------------------------------------------------------------------------------
momentum_filter find_optimal_momentum_filter (
const std::vector<double>& sequence,
const double smoothness = 1
);
// ----------------------------------------------------------------------------------------
class rect_filter
{
public:
rect_filter() = default;
rect_filter(
double meas_noise,
double acc,
double max_meas_dev
) : rect_filter(momentum_filter(meas_noise, acc, max_meas_dev)) {}
rect_filter(
const momentum_filter& filt
) :
left(filt),
top(filt),
right(filt),
bottom(filt)
{
}
drectangle operator()(const drectangle& r)
{
return drectangle(left(r.left()),
top(r.top()),
right(r.right()),
bottom(r.bottom()));
}
drectangle operator()(const rectangle& r)
{
return drectangle(left(r.left()),
top(r.top()),
right(r.right()),
bottom(r.bottom()));
}
const momentum_filter& get_left () const { return left; }
momentum_filter& get_left () { return left; }
const momentum_filter& get_top () const { return top; }
momentum_filter& get_top () { return top; }
const momentum_filter& get_right () const { return right; }
momentum_filter& get_right () { return right; }
const momentum_filter& get_bottom () const { return bottom; }
momentum_filter& get_bottom () { return bottom; }
friend void serialize(const rect_filter& item, std::ostream& out)
{
int version = 123;
serialize(version, out);
serialize(item.left, out);
serialize(item.top, out);
serialize(item.right, out);
serialize(item.bottom, out);
}
friend void deserialize(rect_filter& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 123)
throw dlib::serialization_error("Unknown version number found while deserializing rect_filter object.");
deserialize(item.left, in);
deserialize(item.top, in);
deserialize(item.right, in);
deserialize(item.bottom, in);
}
private:
momentum_filter left, top, right, bottom;
};
// ----------------------------------------------------------------------------------------
rect_filter find_optimal_rect_filter (
const std::vector<rectangle>& rects,
const double smoothness = 1
);
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_KALMAN_FiLTER_Hh_