Vowpal Wabbit
active.cc
Go to the documentation of this file.
1 #include <cerrno>
2 #include "reductions.h"
3 #include "rand48.h"
4 #include <cfloat>
5 #include "vw.h"
6 #include "active.h"
7 #include "vw_exception.h"
8 
9 using namespace LEARNER;
10 using namespace VW::config;
11 
// Computes the probability with which an example should be queried.
//   k        - example count used in the log(k + 1) / k scale term
//   avg_loss - loss estimate; forced into [0, 1] before use (NaN becomes 0)
//   g        - confidence gap of the example
//   c0       - mellowness constant scaling the threshold
// Returns exactly 1 when g does not exceed the threshold sqrt(b) * s + b
// (s = sqrt(loss) + sqrt(loss + g)); otherwise a value that shrinks as g grows.
float get_active_coin_bias(float k, float avg_loss, float g, float c0)
{
  const float scale = static_cast<float>(c0 * (std::log(k + 1.) + 0.0001) / (k + 0.0001));
  const float root_scale = std::sqrt(scale);

  // std::max(0.f, NaN) yields 0.f, so a NaN estimate collapses to 0 here.
  const float loss = std::min(1.f, std::max(0.f, avg_loss));
  const float loss_terms = std::sqrt(loss) + std::sqrt(loss + g);

  if (g <= root_scale * loss_terms + scale)
    return 1.f;

  const float ratio = (loss_terms + std::sqrt(loss_terms * loss_terms + 4.f * g)) / (2.f * g);
  return scale * ratio * ratio;
}
25 
26 float query_decision(active& a, float ec_revert_weight, float k)
27 {
28  float bias, avg_loss, weighted_queries;
29  if (k <= 1.)
30  bias = 1.;
31  else
32  {
33  weighted_queries = (float)a.all->sd->weighted_labeled_examples;
34  avg_loss = (float)(a.all->sd->sum_loss / k + std::sqrt((1. + 0.5 * log(k)) / (weighted_queries + 0.0001)));
35  bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0);
36  }
37  if (a._random_state->get_and_update_random() < bias)
38  return 1.f / bias;
39  else
40  return -1.;
41 }
42 
43 template <bool is_learn>
45 {
46  base.predict(ec);
47 
48  if (is_learn)
49  {
50  vw& all = *a.all;
51 
52  float k = (float)all.sd->t;
53  float threshold = 0.f;
54 
55  ec.confidence = fabsf(ec.pred.scalar - threshold) / base.sensitivity(ec);
56  float importance = query_decision(a, ec.confidence, k);
57 
58  if (importance > 0)
59  {
60  all.sd->queries += 1;
61  ec.weight *= importance;
62  base.learn(ec);
63  }
64  else
65  {
66  ec.l.simple.label = FLT_MAX;
67  ec.weight = 0.f;
68  }
69  }
70 }
71 
72 template <bool is_learn>
74 {
75  if (is_learn)
76  base.learn(ec);
77  else
78  base.predict(ec);
79 
80  if (ec.l.simple.label == FLT_MAX)
81  {
82  float threshold = (a.all->sd->max_label + a.all->sd->min_label) * 0.5f;
83  ec.confidence = fabsf(ec.pred.scalar - threshold) / base.sensitivity(ec);
84  }
85 }
86 
87 void active_print_result(int f, float res, float weight, v_array<char> tag)
88 {
89  if (f >= 0)
90  {
91  std::stringstream ss;
92  ss << std::fixed << res;
93  if (!print_tag(ss, tag))
94  ss << ' ';
95  if (weight >= 0)
96  ss << " " << std::fixed << weight;
97  ss << '\n';
98  ssize_t len = ss.str().size();
99  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
100  if (t != len)
101  std::cerr << "write error: " << strerror(errno) << std::endl;
102  }
103 }
104 
106 {
107  label_data& ld = ec.l.simple;
108 
109  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
110  if (ld.label != FLT_MAX && !ec.test_only)
111  all.sd->weighted_labels += ((double)ld.label) * ec.weight;
112  all.sd->weighted_unlabeled_examples += ld.label == FLT_MAX ? ec.weight : 0;
113 
114  float ai = -1;
115  if (ld.label == FLT_MAX)
116  ai = query_decision(a, ec.confidence, (float)all.sd->weighted_unlabeled_examples);
117 
118  all.print(all.raw_prediction, ec.partial_prediction, -1, ec.tag);
119  for (auto i : all.final_prediction_sink)
120  {
121  active_print_result(i, ec.pred.scalar, ai, ec.tag);
122  }
123 
124  print_update(all, ec);
125 }
126 
128 {
129  output_and_account_example(all, a, ec);
130  VW::finish_example(all, ec);
131 }
132 
134 {
135  auto data = scoped_calloc_or_throw<active>();
136 
137  bool active_option = false;
138  bool simulation = false;
139  option_group_definition new_options("Active Learning");
140  new_options.add(make_option("active", active_option).keep().help("enable active learning"))
141  .add(make_option("simulation", simulation).help("active learning simulation mode"))
142  .add(make_option("mellowness", data->active_c0)
143  .default_value(8.f)
144  .help("active learning mellowness parameter c_0. Default 8"));
145  options.add_and_parse(new_options);
146 
147  if (!active_option)
148  return nullptr;
149 
150  data->all = &all;
151  data->_random_state = all.get_random_state();
152 
153  if (options.was_supplied("lda"))
154  THROW("error: you can't combine lda and active learning");
155 
156  auto base = as_singleline(setup_base(options, all));
157 
158  // Create new learner
160  if (options.was_supplied("simulation"))
161  l = &init_learner(data, base, predict_or_learn_simulation<true>, predict_or_learn_simulation<false>);
162  else
163  {
164  all.active = true;
165  l = &init_learner(data, base, predict_or_learn_active<true>, predict_or_learn_active<false>);
167  }
168 
169  return make_base(*l);
170 }
double sum_loss
Definition: global_data.h:145
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
std::shared_ptr< rand_state > _random_state
Definition: active.h:10
vw * all
Definition: active.h:9
void predict(E &ec, size_t i=0)
Definition: learner.h:169
Definition: active.h:6
void return_active_example(vw &all, active &a, example &ec)
Definition: active.cc:127
float scalar
Definition: example.h:45
double weighted_unlabeled_examples
Definition: global_data.h:143
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
void output_and_account_example(vw &all, active &a, example &ec)
Definition: active.cc:105
float query_decision(active &a, float ec_revert_weight, float k)
Definition: active.cc:26
v_array< int > final_prediction_sink
Definition: global_data.h:518
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
float confidence
Definition: example.h:72
float partial_prediction
Definition: example.h:68
int print_tag(std::stringstream &ss, v_array< char > tag)
Definition: global_data.cc:81
virtual void add_and_parse(const option_group_definition &group)=0
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
void active_print_result(int f, float res, float weight, v_array< char > tag)
Definition: active.cc:87
void predict_or_learn_simulation(active &a, single_learner &base, example &ec)
Definition: active.cc:44
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
base_learner * active_setup(options_i &options, vw &all)
Definition: active.cc:133
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
shared_data * sd
Definition: global_data.h:375
float get_active_coin_bias(float k, float avg_loss, float g, float c0)
Definition: active.cc:12
bool active
Definition: global_data.h:489
size_t num_features
Definition: example.h:67
virtual bool was_supplied(const std::string &key)=0
double weighted_labels
Definition: global_data.h:144
float active_c0
Definition: active.h:8
void finish_example(vw &, example &)
Definition: parser.cc:881
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
float loss
Definition: example.h:70
float sensitivity(example &ec, size_t i=0)
Definition: learner.h:242
float weight
option_group_definition & add(T &&op)
Definition: options.h:90
polylabel l
Definition: example.h:57
constexpr uint64_t a
Definition: rand48.cc:11
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
float min_label
Definition: global_data.h:150
float max_label
Definition: global_data.h:151
double weighted_labeled_examples
Definition: global_data.h:141
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
void predict_or_learn_active(active &a, single_learner &base, example &ec)
Definition: active.cc:73
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float weight
Definition: example.h:62
constexpr int bias
Definition: rand48.cc:14
#define THROW(args)
Definition: vw_exception.h:181
size_t queries
Definition: global_data.h:135
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
float f
Definition: cache.cc:40
bool test_only
Definition: example.h:76