Vowpal Wabbit
loss_functions.h
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5  */
6 #pragma once
7 #include <string>
8 #include "parse_primitives.h"
9 
10 struct shared_data;
11 struct vw;
12 
14 {
15  public:
16  // Identifies the type of the implementing loss function, matches the name used in getLossFunction.
17  virtual std::string getType() = 0;
18 
19  /*
20  * getLoss evaluates the example loss.
21  * The function returns the loss value
22  */
23  // virtual float getLoss(example *&ec, gd_vars &vars) = 0;
24  virtual float getLoss(shared_data*, float prediction, float label) = 0;
25 
26  /*
27  * getUpdate evaluates the update scalar
28  * The function return the update scalar
29  */
30  virtual float getUpdate(float prediction, float label, float update_scale, float pred_per_update) = 0;
31  virtual float getUnsafeUpdate(float prediction, float label, float eta_t) = 0;
32 
33  // the number of examples of the opposite label such that updating with
34  // that number results in the opposite label.
35  // 0 = prediction + pred_per_update
36  // * getUpdate(prediction, opposite, pred_per_update*getRevertingWeight(), pred_per_update)
37  virtual float getRevertingWeight(shared_data*, float prediction, float eta_t) = 0;
38  virtual float getSquareGrad(float prediction, float label) = 0;
39  virtual float first_derivative(shared_data*, float prediction, float label) = 0;
40  virtual float second_derivative(shared_data*, float prediction, float label) = 0;
41  virtual ~loss_function(){};
42 };
43 
44 loss_function* getLossFunction(vw&, std::string funcName, float function_parameter = 0);
virtual float getUpdate(float prediction, float label, float update_scale, float pred_per_update)=0
virtual float second_derivative(shared_data *, float prediction, float label)=0
virtual float getRevertingWeight(shared_data *, float prediction, float eta_t)=0
virtual float first_derivative(shared_data *, float prediction, float label)=0
virtual float getLoss(shared_data *, float prediction, float label)=0
virtual std::string getType()=0
virtual ~loss_function()
loss_function * getLossFunction(vw &, std::string funcName, float function_parameter=0)
virtual float getUnsafeUpdate(float prediction, float label, float eta_t)=0
virtual float getSquareGrad(float prediction, float label)=0