Vowpal Wabbit
cb_algs.h
Go to the documentation of this file.
1 /*
2 Copyright (c) by respective owners including Yahoo!, Microsoft, and
3 individual contributors. All rights reserved. Released under a BSD
4 license as described in the file LICENSE.
5  */
6 #pragma once
7 
8 #include "baseline.h"
9 
10 // TODO: extend to handle CSOAA_LDF and WAP_LDF
12 
// Identifiers for the contextual-bandit cost-estimation strategies.
// NOTE(review): kept as macros (not constexpr) because other translation units may
// rely on them in preprocessor context; confirm before converting.
#define CB_TYPE_DR 0   // doubly robust
#define CB_TYPE_DM 1   // direct method
#define CB_TYPE_IPS 2  // inverse propensity score
#define CB_TYPE_MTR 3  // "mtr" cb_type (expansion not visible here; confirm)
#define CB_TYPE_SM 4   // "sm" cb_type (expansion not visible here; confirm)
18 
19 namespace CB_ALGS
20 {
21 template <bool is_learn>
23  LEARNER::single_learner* scorer, CB::cb_class* known_cost, example& ec, uint32_t index, uint32_t base)
24 {
25  CB::label ld = ec.l.cb;
26 
27  label_data simple_temp;
28  simple_temp.initial = 0.;
29  if (known_cost != nullptr && index == known_cost->action)
30  simple_temp.label = known_cost->cost;
31  else
32  simple_temp.label = FLT_MAX;
33 
34  const bool baseline_enabled_old = BASELINE::baseline_enabled(&ec);
36  ec.l.simple = simple_temp;
37  polyprediction p = ec.pred;
38  if (is_learn && known_cost != nullptr && index == known_cost->action)
39  {
40  float old_weight = ec.weight;
41  ec.weight /= known_cost->probability;
42  scorer->learn(ec, index - 1 + base);
43  ec.weight = old_weight;
44  }
45  else
46  scorer->predict(ec, index - 1 + base);
47 
48  if (!baseline_enabled_old)
50  float pred = ec.pred.scalar;
51  ec.pred = p;
52 
53  ec.l.cb = ld;
54 
55  return pred;
56 }
57 
58 inline float get_cost_estimate(CB::cb_class* observation, uint32_t action, float offset = 0.)
59 {
60  if (action == observation->action)
61  return (observation->cost - offset) / observation->probability;
62  return 0.;
63 }
64 
65 inline float get_cost_estimate(CB::cb_class* observation, COST_SENSITIVE::label& scores, uint32_t action)
66 {
67  for (auto& cl : scores.costs)
68  if (cl.class_index == action)
69  return get_cost_estimate(observation, action, cl.x) + cl.x;
70  return get_cost_estimate(observation, action);
71 }
72 
73 inline float get_cost_estimate(ACTION_SCORE::action_score& a_s, float cost, uint32_t action, float offset = 0.)
74 {
75  if (action == a_s.action)
76  return (cost - offset) / a_s.score;
77  return 0.;
78 }
79 
80 inline bool example_is_newline_not_header(example const& ec)
81 {
82  return (example_is_newline(ec) && !CB::ec_is_example_header(ec));
83 }
84 } // namespace CB_ALGS
bool example_is_newline_not_header(example const &ec)
Definition: cb_algs.h:80
Definition: scorer.cc:8
void predict(E &ec, size_t i=0)
Definition: learner.h:169
bool ec_is_example_header(example const &ec)
Definition: cb.cc:170
float scalar
Definition: example.h:45
CB::label cb
Definition: example.h:31
void set_baseline_enabled(example *ec)
Definition: baseline.cc:23
uint32_t action
Definition: search.h:19
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
float get_cost_estimate(CB::cb_class *observation, uint32_t action, float offset=0.)
Definition: cb_algs.h:58
int example_is_newline(example const &ec)
Definition: example.h:104
uint32_t action
Definition: cb.h:18
float probability
Definition: cb.h:19
LEARNER::base_learner * cb_algs_setup(VW::config::options_i &options, vw &all)
Definition: cb_algs.cc:132
float get_cost_pred(LEARNER::single_learner *scorer, CB::cb_class *known_cost, example &ec, uint32_t index, uint32_t base)
Definition: cb_algs.h:22
void reset_baseline_disabled(example *ec)
Definition: baseline.cc:38
float initial
Definition: simple_label.h:16
polylabel l
Definition: example.h:57
Definition: cb.h:25
float cost
Definition: cb.h:17
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
v_array< wclass > costs
float weight
Definition: example.h:62
bool baseline_enabled(example *ec)
Definition: baseline.cc:51