13 #define W_XT 0 // current parameter 14 #define W_ZT 1 // in proximal is "accumulated z(t) = z(t-1) + g(t) + sigma*w(t)", in general is the dual weight vector 15 #define W_G2 2 // accumulated gradient information 16 #define W_MX 3 // maximum absolute value 17 #define W_WE 4 // Wealth 18 #define W_MG 5 // maximum gradient 67 float sqrtf_ng2 = sqrtf(w[
W_G2]);
74 GD::foreach_feature<uncertainty, predict_with_confidence>(*(b.
all), ec, uncetain);
75 return uncetain.
score;
106 if (finalize_predictions)
110 for (
size_t c = 0; c < count; c++)
116 ec.
ft_offset -= (uint64_t)(step * count);
123 float gradient = d.
update * x;
124 float ng2 = w[
W_G2] + gradient * gradient;
125 float sqrt_ng2 = sqrtf(ng2);
126 float sqrt_wW_G2 = sqrtf(w[
W_G2]);
127 float sigma = (sqrt_ng2 - sqrt_wW_G2) / d.
ftrl_alpha;
128 w[
W_ZT] += gradient - sigma * w[
W_XT];
130 sqrt_wW_G2 = sqrt_ng2;
132 float fabs_zt = w[
W_ZT] * flag;
146 float fabs_x = fabs(x);
147 if (fabs_x > w[
W_MX])
150 float squared_theta = w[
W_ZT] * w[
W_ZT];
160 float gradient = d.
update * x;
162 w[
W_ZT] += -gradient;
163 w[
W_G2] += fabs(gradient);
176 float w_mx = w[
W_MX];
179 float fabs_x = fabs(x);
186 if (w[
W_MG] * w_mx > 0)
197 float fabs_x = fabs(x);
198 float gradient = d.
update * x;
200 if (fabs_x > w[
W_MX])
205 float fabs_gradient = fabs(d.
update);
206 if (fabs_gradient > w[
W_MG])
212 if (w[W_MG] * w[W_MX] > 0)
217 w[
W_ZT] += -gradient;
218 w[
W_G2] += fabs(gradient);
227 GD::foreach_feature<update_data, inner_update_cb_state_and_predict>(*b.
all, ec, b.
data);
241 GD::foreach_feature<update_data, inner_update_pistol_state_and_predict>(*b.
all, ec, b.
data);
250 GD::foreach_feature<update_data, inner_update_proximal>(*b.
all, ec, b.
data);
257 GD::foreach_feature<update_data, inner_update_pistol_post>(*b.
all, ec, b.
data);
264 GD::foreach_feature<update_data, inner_update_cb_post>(*b.
all, ec, b.
data);
267 template <
bool audit>
273 predict<audit>(
a, base, ec);
310 std::stringstream msg;
311 msg <<
":" << resume <<
"\n";
337 auto b = scoped_calloc_or_throw<ftrl>();
338 bool ftrl_option =
false;
343 new_options.
add(
make_option(
"ftrl", ftrl_option).keep().help(
"FTRL: Follow the Proximal Regularized Leader"))
344 .
add(
make_option(
"coin", coin).keep().help(
"Coin betting optimizer"))
345 .
add(
make_option(
"pistol", pistol).keep().help(
"PiSTOL: Parameter-free STOchastic Learning"))
346 .
add(
make_option(
"ftrl_alpha", b->ftrl_alpha).help(
"Learning rate for FTRL optimization"))
347 .
add(
make_option(
"ftrl_beta", b->ftrl_beta).help(
"Learning rate for FTRL optimization"));
350 if (!ftrl_option && !pistol && !coin)
358 b->ftrl_alpha = options.
was_supplied(
"ftrl_alpha") ? b->ftrl_alpha : 0.005f;
359 b->ftrl_beta = options.
was_supplied(
"ftrl_beta") ? b->ftrl_beta : 0.1f;
363 b->ftrl_alpha = options.
was_supplied(
"ftrl_alpha") ? b->ftrl_alpha : 1.0f;
364 b->ftrl_beta = options.
was_supplied(
"ftrl_beta") ? b->ftrl_beta : 0.5f;
368 b->ftrl_alpha = options.
was_supplied(
"ftrl_alpha") ? b->ftrl_alpha : 4.0f;
369 b->ftrl_beta = options.
was_supplied(
"ftrl_beta") ? b->ftrl_beta : 1.0f;
373 b->no_win_counter = 0;
379 std::string algorithm_name;
382 algorithm_name =
"Proximal-FTRL";
384 learn_ptr = learn_proximal<true>;
386 learn_ptr = learn_proximal<false>;
392 algorithm_name =
"PiSTOL";
399 algorithm_name =
"Coin Betting";
405 b->data.ftrl_alpha = b->ftrl_alpha;
406 b->data.ftrl_beta = b->ftrl_beta;
407 b->data.l1_lambda = b->all->l1_lambda;
408 b->data.l2_lambda = b->all->l2_lambda;
412 std::cerr <<
"Enabling FTRL based optimization" << std::endl;
413 std::cerr <<
"Algorithm used: " << algorithm_name << std::endl;
414 std::cerr <<
"ftrl_alpha = " << b->ftrl_alpha << std::endl;
415 std::cerr <<
"ftrl_beta = " << b->ftrl_beta << std::endl;
421 b->early_stop_thres = options.
get_typed_option<
size_t>(
"early_terminate").value();
431 l->set_multipredict(multipredict<true>);
433 l->set_multipredict(multipredict<false>);
void update_after_prediction_pistol(ftrl &b, example &ec)
void update_after_prediction_cb(ftrl &b, example &ec)
void update_state_and_predict_pistol(ftrl &b, single_learner &, example &ec)
float finalize_prediction(shared_data *sd, float ret)
void print_audit_features(vw &all, example &ec)
void initialize_regressor(vw &all, T &weights)
void predict_with_confidence(uncertainty &d, const float fx, float &fw)
void inner_update_cb_post(update_data &d, float x, float &wref)
base_learner * make_base(learner< T, E > &base)
float sensitivity(ftrl &b, base_learner &, example &ec)
void vec_add_multipredict(multipredict_info< T > &mp, const float fx, uint64_t fi)
void predict(ftrl &b, single_learner &, example &ec)
void finalize_regressor(vw &all, std::string reg_name)
virtual void add_and_parse(const option_group_definition &group)=0
size_t check_holdout_every_n_passes
bool summarize_holdout_set(vw &all, size_t &no_win_counter)
void inner_update_pistol_state_and_predict(update_data &d, float x, float &wref)
virtual float first_derivative(shared_data *, float prediction, float label)=0
float inline_predict(vw &all, example &ec)
void save_load_online_state(vw &all, io_buf &model_file, bool read, bool text, gd *g, std::stringstream &msg, uint32_t ftrl_size, T &weights)
void inner_update_cb_state_and_predict(update_data &d, float x, float &wref)
float normalized_squared_norm_x
learner< T, E > & init_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), size_t ws, prediction_type::prediction_type_t pred_type)
typed_option< T > & get_typed_option(const std::string &key)
void end_pass(example &ec, vw &all)
virtual bool was_supplied(const std::string &key)=0
base_learner * ftrl_setup(options_i &options, vw &all)
dense_parameters dense_weights
void update_after_prediction_proximal(ftrl &b, example &ec)
option_group_definition & add(T &&op)
void inner_update_pistol_post(update_data &d, float x, float &wref)
int add(svm_params ¶ms, svm_example *fec)
typed_option< T > make_option(std::string name, T &location)
constexpr uint64_t UINT64_ONE
void learn_pistol(ftrl &a, single_learner &base, example &ec)
sparse_parameters sparse_weights
void update_state_and_predict_cb(ftrl &b, single_learner &, example &ec)
void multipredict(ftrl &b, base_learner &, example &ec, size_t count, size_t step, polyprediction *pred, bool finalize_predictions)
void inner_update_proximal(update_data &d, float x, float &wref)
void save_load_regressor(vw &all, io_buf &model_file, bool read, bool text, T &weights)
std::string final_regressor_name
void learn_cb(ftrl &a, single_learner &base, example &ec)
void save_load(ftrl &b, io_buf &model_file, bool read, bool text)
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
void learn_proximal(ftrl &a, single_learner &base, example &ec)
double normalized_sum_norm_x