Vowpal Wabbit
Functions
active.cc File Reference
#include <cerrno>
#include "reductions.h"
#include "rand48.h"
#include <cfloat>
#include "vw.h"
#include "active.h"
#include "vw_exception.h"

Go to the source code of this file.

Functions

float get_active_coin_bias (float k, float avg_loss, float g, float c0)
 
float query_decision (active &a, float ec_revert_weight, float k)
 
template<bool is_learn>
void predict_or_learn_simulation (active &a, single_learner &base, example &ec)
 
template<bool is_learn>
void predict_or_learn_active (active &a, single_learner &base, example &ec)
 
void active_print_result (int f, float res, float weight, v_array< char > tag)
 
void output_and_account_example (vw &all, active &a, example &ec)
 
void return_active_example (vw &all, active &a, example &ec)
 
base_learner * active_setup (options_i &options, vw &all)
 

Function Documentation

◆ active_print_result()

void active_print_result ( int  f,
float  res,
float  weight,
v_array< char >  tag 
)

Definition at line 87 of file active.cc.

References print_tag(), and io_buf::write_file_or_socket().

Referenced by output_and_account_example().

88 {
89  if (f >= 0)
90  {
91  std::stringstream ss;
92  ss << std::fixed << res;
93  if (!print_tag(ss, tag))
94  ss << ' ';
95  if (weight >= 0)
96  ss << " " << std::fixed << weight;
97  ss << '\n';
98  ssize_t len = ss.str().size();
99  ssize_t t = io_buf::write_file_or_socket(f, ss.str().c_str(), (unsigned int)len);
100  if (t != len)
101  std::cerr << "write error: " << strerror(errno) << std::endl;
102  }
103 }
static ssize_t write_file_or_socket(int f, const void *buf, size_t nbytes)
Definition: io_buf.cc:140
int print_tag(std::stringstream &ss, v_array< char > tag)
Definition: global_data.cc:81
float weight
float f
Definition: cache.cc:40

◆ active_setup()

base_learner* active_setup ( options_i & options,
vw & all 
)

Definition at line 133 of file active.cc.

References vw::active, VW::config::option_group_definition::add(), VW::config::options_i::add_and_parse(), LEARNER::as_singleline(), vw::get_random_state(), LEARNER::init_learner(), LEARNER::make_base(), VW::config::make_option(), return_active_example(), LEARNER::learner< T, E >::set_finish_example(), setup_base(), THROW, and VW::config::options_i::was_supplied().

Referenced by parse_reductions().

134 {
135  auto data = scoped_calloc_or_throw<active>();
136 
137  bool active_option = false;
138  bool simulation = false;
139  option_group_definition new_options("Active Learning");
140  new_options.add(make_option("active", active_option).keep().help("enable active learning"))
141  .add(make_option("simulation", simulation).help("active learning simulation mode"))
142  .add(make_option("mellowness", data->active_c0)
143  .default_value(8.f)
144  .help("active learning mellowness parameter c_0. Default 8"));
145  options.add_and_parse(new_options);
146 
147  if (!active_option)
148  return nullptr;
149 
150  data->all = &all;
151  data->_random_state = all.get_random_state();
152 
153  if (options.was_supplied("lda"))
154  THROW("error: you can't combine lda and active learning");
155 
156  auto base = as_singleline(setup_base(options, all));
157 
158  // Create new learner
159  learner<active, example>* l;
160  if (options.was_supplied("simulation"))
161  l = &init_learner(data, base, predict_or_learn_simulation<true>, predict_or_learn_simulation<false>);
162  else
163  {
164  all.active = true;
165  l = &init_learner(data, base, predict_or_learn_active<true>, predict_or_learn_active<false>);
166  l->set_finish_example(return_active_example);
167  }
168 
169  return make_base(*l);
170 }
void return_active_example(vw &all, active &a, example &ec)
Definition: active.cc:127
base_learner * make_base(learner< T, E > &base)
Definition: learner.h:462
virtual void add_and_parse(const option_group_definition &group)=0
std::shared_ptr< rand_state > get_random_state()
Definition: global_data.h:553
single_learner * as_singleline(learner< T, E > *l)
Definition: learner.h:476
void set_finish_example(void(*f)(vw &all, T &, E &))
Definition: learner.h:307
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
Definition: learner.h:369
bool active
Definition: global_data.h:489
virtual bool was_supplied(const std::string &key)=0
typed_option< T > make_option(std::string name, T &location)
Definition: options.h:80
LEARNER::base_learner * setup_base(options_i &options, vw &all)
Definition: parse_args.cc:1222
#define THROW(args)
Definition: vw_exception.h:181

◆ get_active_coin_bias()

float get_active_coin_bias ( float  k,
float  avg_loss,
float  g,
float  c0 
)

Definition at line 12 of file active.cc.

References f.

Referenced by query_decision().

13 {
14  float b, sb, rs, sl;
15  b = (float)(c0 * (log(k + 1.) + 0.0001) / (k + 0.0001));
16  sb = std::sqrt(b);
17  avg_loss = std::min(1.f, std::max(0.f, avg_loss)); // loss should be in [0,1]
18 
19  sl = std::sqrt(avg_loss) + std::sqrt(avg_loss + g);
20  if (g <= sb * sl + b)
21  return 1;
22  rs = (sl + std::sqrt(sl * sl + 4 * g)) / (2 * g);
23  return b * rs * rs;
24 }
float f
Definition: cache.cc:40

◆ output_and_account_example()

void output_and_account_example ( vw & all,
active & a,
example & ec 
)

Definition at line 105 of file active.cc.

References active_print_result(), example::confidence, vw::final_prediction_sink, example::l, label_data::label, example::loss, example::num_features, example::partial_prediction, example::pred, vw::print, CB::print_update(), query_decision(), vw::raw_prediction, polyprediction::scalar, vw::sd, polylabel::simple, example::tag, example::test_only, shared_data::update(), example::weight, shared_data::weighted_labels, and shared_data::weighted_unlabeled_examples.

Referenced by finish_example(), keep_example(), return_active_example(), return_example(), and no_label::return_no_label_example().

106 {
107  label_data& ld = ec.l.simple;
108 
109  all.sd->update(ec.test_only, ld.label != FLT_MAX, ec.loss, ec.weight, ec.num_features);
110  if (ld.label != FLT_MAX && !ec.test_only)
111  all.sd->weighted_labels += ((double)ld.label) * ec.weight;
112  all.sd->weighted_unlabeled_examples += ld.label == FLT_MAX ? ec.weight : 0;
113 
114  float ai = -1;
115  if (ld.label == FLT_MAX)
116  ai = query_decision(a, ec.confidence, (float)all.sd->weighted_unlabeled_examples);
117 
118  all.print(all.raw_prediction, ec.partial_prediction, -1, ec.tag);
119  for (auto i : all.final_prediction_sink)
120  {
121  active_print_result(i, ec.pred.scalar, ai, ec.tag);
122  }
123 
124  print_update(all, ec);
125 }
v_array< char > tag
Definition: example.h:63
int raw_prediction
Definition: global_data.h:519
float scalar
Definition: example.h:45
double weighted_unlabeled_examples
Definition: global_data.h:143
float query_decision(active &a, float ec_revert_weight, float k)
Definition: active.cc:26
v_array< int > final_prediction_sink
Definition: global_data.h:518
float confidence
Definition: example.h:72
float partial_prediction
Definition: example.h:68
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
void active_print_result(int f, float res, float weight, v_array< char > tag)
Definition: active.cc:87
void print_update(vw &all, bool is_test, example &ec, multi_ex *ec_seq, bool action_scores)
Definition: cb.cc:180
shared_data * sd
Definition: global_data.h:375
size_t num_features
Definition: example.h:67
double weighted_labels
Definition: global_data.h:144
void update(bool test_example, bool labeled_example, float loss, float weight, size_t num_features)
Definition: global_data.h:190
float loss
Definition: example.h:70
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
float weight
Definition: example.h:62
void(* print)(int, float, float, v_array< char >)
Definition: global_data.h:521
bool test_only
Definition: example.h:76

◆ predict_or_learn_active()

template<bool is_learn>
void predict_or_learn_active ( active & a,
single_learner & base,
example & ec 
)

Definition at line 73 of file active.cc.

References active::all, example::confidence, f, example::l, label_data::label, LEARNER::learner< T, E >::learn(), shared_data::max_label, shared_data::min_label, example::pred, LEARNER::learner< T, E >::predict(), polyprediction::scalar, vw::sd, LEARNER::learner< T, E >::sensitivity(), and polylabel::simple.

74 {
75  if (is_learn)
76  base.learn(ec);
77  else
78  base.predict(ec);
79 
80  if (ec.l.simple.label == FLT_MAX)
81  {
82  float threshold = (a.all->sd->max_label + a.all->sd->min_label) * 0.5f;
83  ec.confidence = fabsf(ec.pred.scalar - threshold) / base.sensitivity(ec);
84  }
85 }
vw * all
Definition: active.h:9
void predict(E &ec, size_t i=0)
Definition: learner.h:169
float scalar
Definition: example.h:45
float confidence
Definition: example.h:72
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
shared_data * sd
Definition: global_data.h:375
float sensitivity(example &ec, size_t i=0)
Definition: learner.h:242
polylabel l
Definition: example.h:57
float min_label
Definition: global_data.h:150
float max_label
Definition: global_data.h:151
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float f
Definition: cache.cc:40

◆ predict_or_learn_simulation()

template<bool is_learn>
void predict_or_learn_simulation ( active & a,
single_learner & base,
example & ec 
)

Definition at line 44 of file active.cc.

References active::all, example::confidence, example::l, label_data::label, LEARNER::learner< T, E >::learn(), example::pred, LEARNER::learner< T, E >::predict(), shared_data::queries, query_decision(), polyprediction::scalar, vw::sd, LEARNER::learner< T, E >::sensitivity(), polylabel::simple, shared_data::t, and example::weight.

45 {
46  base.predict(ec);
47 
48  if (is_learn)
49  {
50  vw& all = *a.all;
51 
52  float k = (float)all.sd->t;
53  float threshold = 0.f;
54 
55  ec.confidence = fabsf(ec.pred.scalar - threshold) / base.sensitivity(ec);
56  float importance = query_decision(a, ec.confidence, k);
57 
58  if (importance > 0)
59  {
60  all.sd->queries += 1;
61  ec.weight *= importance;
62  base.learn(ec);
63  }
64  else
65  {
66  ec.l.simple.label = FLT_MAX;
67  ec.weight = 0.f;
68  }
69  }
70 }
vw * all
Definition: active.h:9
void predict(E &ec, size_t i=0)
Definition: learner.h:169
float scalar
Definition: example.h:45
float query_decision(active &a, float ec_revert_weight, float k)
Definition: active.cc:26
float confidence
Definition: example.h:72
float label
Definition: simple_label.h:14
label_data simple
Definition: example.h:28
shared_data * sd
Definition: global_data.h:375
float sensitivity(example &ec, size_t i=0)
Definition: learner.h:242
polylabel l
Definition: example.h:57
polyprediction pred
Definition: example.h:60
void learn(E &ec, size_t i=0)
Definition: learner.h:160
float weight
Definition: example.h:62
size_t queries
Definition: global_data.h:135

◆ query_decision()

float query_decision ( active & a,
float  ec_revert_weight,
float  k 
)

Definition at line 26 of file active.cc.

References active::_random_state, active::active_c0, active::all, bias, f, get_active_coin_bias(), vw::sd, shared_data::sum_loss, and shared_data::weighted_labeled_examples.

Referenced by output_and_account_example(), and predict_or_learn_simulation().

27 {
28  float bias, avg_loss, weighted_queries;
29  if (k <= 1.)
30  bias = 1.;
31  else
32  {
33  weighted_queries = (float)a.all->sd->weighted_labeled_examples;
34  avg_loss = (float)(a.all->sd->sum_loss / k + std::sqrt((1. + 0.5 * log(k)) / (weighted_queries + 0.0001)));
35  bias = get_active_coin_bias(k, avg_loss, ec_revert_weight / k, a.active_c0);
36  }
37  if (a._random_state->get_and_update_random() < bias)
38  return 1.f / bias;
39  else
40  return -1.;
41 }
double sum_loss
Definition: global_data.h:145
std::shared_ptr< rand_state > _random_state
Definition: active.h:10
vw * all
Definition: active.h:9
shared_data * sd
Definition: global_data.h:375
float get_active_coin_bias(float k, float avg_loss, float g, float c0)
Definition: active.cc:12
float active_c0
Definition: active.h:8
double weighted_labeled_examples
Definition: global_data.h:141
constexpr int bias
Definition: rand48.cc:14
float f
Definition: cache.cc:40

◆ return_active_example()

void return_active_example ( vw & all,
active & a,
example & ec 
)

Definition at line 127 of file active.cc.

References VW::finish_example(), and output_and_account_example().

Referenced by active_setup().

128 {
129  output_and_account_example(all, a, ec);
130  VW::finish_example(all, ec);
131 }
void output_and_account_example(vw &all, active &a, example &ec)
Definition: active.cc:105
void finish_example(vw &, example &)
Definition: parser.cc:881