87 static const float alpha = 2.0f;
89 return std::max(0.
f, std::min(1.
f, 0.5
f * (1.0
f + alpha * x)));
102 b.
nodes[root].base_router = routers_used++;
104 b.
nodes[root].internal =
true;
105 b.
nodes[root].left = left_child;
106 b.
nodes[left_child].parent = root;
107 b.
nodes[left_child].depth = depth;
108 b.
nodes[root].right = right_child;
109 b.
nodes[right_child].parent = root;
110 b.
nodes[right_child].depth = depth;
112 init_tree(b, left_child, depth + 1, routers_used);
113 init_tree(b, right_child, depth + 1, routers_used);
119 uint32_t routers_used = 0;
155 double mass_at_k = 0;
159 mass_at_k += ls->label_count;
162 float f = (float)mass_at_k / (
float)n->
n;
163 float stdf = std::sqrt(f * (1.f - f) / (
float)n->
n);
164 float diamf = 15.f / (std::sqrt(18.f) * (float)n->
n);
/// Entropy term p * log(p) for the empirical probability p = c / n.
///
/// @param c  mass (count or weight) observed for one outcome.
/// @param n  total mass; must be non-zero whenever c is non-zero, because the
///           ratio c / n is formed without any guard.
/// @return   0 when c == 0 (the usual 0 * log 0 = 0 convention), otherwise
///           (c / n) * log(c / n) using the natural logarithm.
double plogp(double c, double n)
{
  // An empty outcome contributes nothing; this also avoids evaluating log(0).
  if (c == 0)
  {
    return 0;
  }

  const double p = c / n;
  return p * std::log(p);
}
183 double deltac0 = ec.
weight;
184 double n = b.
nodes[cn].n;
186 double novernp1 = n / (deltac0 + n);
187 double lognovernp1 = (novernp1 == 0) ? 0 : log(novernp1);
188 double nminusc0overnp1 = (n - c0) / (n + deltac0);
190 double newentropy = b.
nodes[cn].entropy;
192 newentropy +=
plogp(c0, n);
193 newentropy *= novernp1;
194 newentropy -= lognovernp1 * nminusc0overnp1;
195 newentropy -=
plogp(c0 + deltac0, n + deltac0);
210 std::swap(ls[-1], ls[0]);
232 fs.
push_back(1., (((uint64_t)868771 * cn) << ss) & mask);
238 fs.
push_back(1., (((uint64_t)868771 * cn) << ss) & mask);
239 cn = b.
nodes[cn].parent;
259 uint32_t amaxscore = 0;
262 ec.
l.
simple = {FLT_MAX, 0.f, 0.f};
264 float maxscore = std::numeric_limits<float>::lowest();
272 amaxscore = ls->label;
316 ec.
l.
simple = {FLT_MAX, 0.f, 0.f};
317 while (b.
nodes[cn].internal)
351 double old_left = b.
nodes[b.
nodes[cn].left].entropy;
352 double old_right = b.
nodes[b.
nodes[cn].right].entropy;
355 double delta_left = nl * (new_left - old_left) + mc.
weight * new_left;
356 double delta_right = nr * (new_right - old_right) + mc.
weight * new_right;
357 float route_label = delta_left < delta_right ? -1.f : 1.f;
358 float imp_weight = fabs((
float)(delta_left - delta_right));
360 ec.
l.
simple = {route_label, imp_weight, 0.};
384 while (b.
nodes[cn].internal)
404 if (!b.
nodes[cn].internal)
416 ec.
l.
simple = {-1.f, 1.f, 0.f};
421 if (ls->label != mc.
label)
437 std::stringstream msg;
446 for (uint32_t j = 0; j < n_nodes; ++j)
455 for (uint32_t j = 0; j < n_nodes; ++j)
476 for (uint32_t k = 0; k < n_preds; ++k)
482 for (uint32_t k = 0; k < n_preds; ++k)
504 auto tree = scoped_calloc_or_throw<recall_tree>();
506 new_options.add(
make_option(
"recall_tree", tree->k).keep().help(
"Use online tree for multiclass"))
509 .help(
"maximum number of labels per leaf in the tree"))
510 .
add(
make_option(
"bern_hyper", tree->bern_hyper).default_value(1.
f).help(
"recall tree depth penalty"))
511 .
add(
make_option(
"max_depth", tree->max_depth).keep().help(
"maximum depth of the tree, default log_2 (#classes)"))
512 .
add(
make_option(
"node_only", tree->node_only).keep().help(
"only use node features, not full path features"))
513 .
add(
make_option(
"randomized_routing", tree->randomized_routing).keep().help(
"randomized routing"));
521 tree->max_candidates = options.
was_supplied(
"max_candidates")
522 ? tree->max_candidates
523 : std::min(tree->k, 4 * (uint32_t)(ceil(log(tree->k) / log(2.0))));
525 options.
was_supplied(
"max_depth") ? tree->max_depth : (uint32_t)std::ceil(std::log(tree->k) / std::log(2.0));
531 <<
" node_only = " << tree->node_only <<
" bern_hyper = " << tree->bern_hyper
532 <<
" max_depth = " << tree->max_depth <<
" routing = " 533 << (all.
training ? (tree->randomized_routing ?
"randomized" :
"deterministic") :
"n/a testonly")
v_array< namespace_index > indices
std::shared_ptr< rand_state > _random_state
void predict(E &ec, size_t i=0)
uint32_t oas_predict(recall_tree &b, single_learner &base, uint32_t cn, example &ec)
void push_back(feature_value v, feature_index i)
bool stop_recurse_check(recall_tree &b, uint32_t parent, uint32_t child)
the core definition of a set of features.
base_learner * make_base(learner< T, E > &base)
void predict(recall_tree &b, single_learner &base, example &ec)
predict_type(uint32_t a, uint32_t b)
virtual void add_and_parse(const option_group_definition &group)=0
void set_save_load(void(*sl)(T &, io_buf &, bool, bool))
node_pred * find_or_create(recall_tree &b, uint32_t cn, example &ec)
void add_node_id_feature(recall_tree &b, uint32_t cn, example &ec)
void insert_example_at_node(recall_tree &b, uint32_t cn, example &ec)
#define writeit(what, str)
void train_node(log_multi &b, single_learner &base, example &ec, uint32_t &current, uint32_t &class_index, uint32_t)
std::shared_ptr< rand_state > get_random_state()
std::array< features, NUM_NAMESPACES > feature_space
single_learner * as_singleline(learner< T, E > *l)
MULTICLASS::label_t multi
void learn(recall_tree &b, single_learner &base, example &ec)
void compute_recall_lbest(recall_tree &b, node *n)
void push_back(const T &new_ele)
virtual bool was_supplied(const std::string &key)=0
base_learner * recall_tree_setup(options_i &options, vw &all)
bool is_candidate(recall_tree &b, uint32_t cn, example &ec)
double plogp(double c, double n)
void remove_node_id_feature(recall_tree &, uint32_t, example &ec)
int add(svm_params &params, svm_example *fec)
predict_type predict_from(recall_tree &b, single_learner &base, example &ec, uint32_t cn)
node_pred * find(recall_tree &b, uint32_t cn, example &ec)
void save_load_tree(log_multi &b, io_buf &model_file, bool read, bool text)
double updated_entropy(recall_tree &b, uint32_t cn, example &ec)
typed_option< T > make_option(std::string name, T &location)
void init_tree(log_multi &d)
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
LEARNER::base_learner * setup_base(options_i &options, vw &all)
void learn(E &ec, size_t i=0)
constexpr unsigned char node_id_namespace
v_array< node_pred > preds
uint32_t descend(node &n, float prediction)
uint32_t class_prediction
#define writeitvar(what, str, mywhat)