116 node.
preds = v_init<node_pred>();
136 while (b.
nodes[node].internal)
138 node = b.
nodes[node].left;
140 node = b.
nodes[node].right;
149 uint32_t prev = node;
150 node = b.
nodes[node].parent;
152 if (b.
nodes[node].min_count == b.
nodes[prev].min_count)
161 for (uint32_t i = 0; i < depth; i++) std::cout <<
"\t";
164 for (
size_t i = 0; i < node.
preds.
size(); i++)
165 std::cout << node.
preds[i].label <<
":" << node.
preds[i].label_count <<
"\t";
166 std::cout << std::endl;
173 std::cout <<
"Right";
181 b.
nodes[current].preds[class_index].label_count++;
183 if (b.
nodes[current].preds[class_index].label_count > b.
nodes[current].max_count)
185 b.
nodes[current].max_count = b.
nodes[current].preds[class_index].label_count;
186 b.
nodes[current].max_count_label = b.
nodes[current].preds[class_index].label;
189 if (b.
nodes[current].internal)
191 else if (b.
nodes[current].preds.
size() > 1 &&
197 uint32_t right_child;
209 uint32_t swap_parent = b.
nodes[swap_child].parent;
210 uint32_t swap_grandparent = b.
nodes[swap_parent].parent;
211 if (b.
nodes[swap_child].min_count != b.
nodes[0].min_count)
212 std::cout <<
"glargh " << b.
nodes[swap_child].min_count <<
" != " << b.
nodes[0].min_count << std::endl;
215 uint32_t nonswap_child;
216 if (swap_child == b.
nodes[swap_parent].right)
217 nonswap_child = b.
nodes[swap_parent].left;
219 nonswap_child = b.
nodes[swap_parent].right;
221 if (swap_parent == b.
nodes[swap_grandparent].left)
222 b.
nodes[swap_grandparent].left = nonswap_child;
224 b.
nodes[swap_grandparent].right = nonswap_child;
225 b.
nodes[nonswap_child].parent = swap_grandparent;
229 left_child = swap_child;
230 b.
nodes[current].base_predictor = b.
nodes[swap_parent].base_predictor;
232 right_child = swap_parent;
234 b.
nodes[current].left = left_child;
235 b.
nodes[left_child].parent = current;
236 b.
nodes[current].right = right_child;
237 b.
nodes[right_child].parent = current;
239 b.
nodes[left_child].min_count = b.
nodes[current].min_count / 2;
240 b.
nodes[right_child].min_count = b.
nodes[current].min_count - b.
nodes[left_child].min_count;
243 b.
nodes[left_child].max_count_label = b.
nodes[current].max_count_label;
244 b.
nodes[right_child].max_count_label = b.
nodes[current].max_count_label;
246 b.
nodes[current].internal =
true;
248 return b.
nodes[current].internal;
254 if (b.
nodes[current].norm_Eh > b.
nodes[current].preds[class_index].norm_Ehk)
259 base.
learn(ec, b.
nodes[current].base_predictor);
266 b.
nodes[current].n++;
267 b.
nodes[current].preds[class_index].nk++;
269 b.
nodes[current].norm_Eh = (float)b.
nodes[current].Eh / b.
nodes[current].n;
270 b.
nodes[current].preds[class_index].norm_Ehk =
271 (
float)b.
nodes[current].preds[class_index].Ehk / b.
nodes[current].preds[class_index].nk;
280 std::cout <<
"badness! " << std::endl;
308 ec.
l.
simple = {FLT_MAX, 0.f, 0.f};
311 while (b.
nodes[cn].internal)
332 uint32_t class_index = 0;
333 ec.
l.
simple = {FLT_MAX, 0.f, 0.f};
338 train_node(b, base, ec, cn, class_index, depth);
343 b.
nodes[cn].min_count++;
357 fp = fopen(
"atxm_debug.csv",
"wt");
361 fprintf(fp,
"Node: %4d, Internal: %1d, Eh: %7.4f, n: %6d, \n", (
int)i, (
int)b->
nodes[i].internal,
364 fprintf(fp,
"Label:, ");
365 for (j = 0; j < b->
nodes[i].preds.
size(); j++)
367 fprintf(fp,
"%6d,", (
int)b->
nodes[i].preds[j].label);
371 fprintf(fp,
"Ehk:, ");
372 for (j = 0; j < b->
nodes[i].preds.
size(); j++)
374 fprintf(fp,
"%7.4f,", b->
nodes[i].preds[j].Ehk / b->
nodes[i].preds[j].nk);
380 fprintf(fp,
"nk:, ");
381 for (j = 0; j < b->
nodes[i].preds.
size(); j++)
383 fprintf(fp,
"%6d,", (
int)b->
nodes[i].preds[j].nk);
384 total += b->
nodes[i].preds[j].nk;
388 fprintf(fp,
"max(lab:cnt:tot):, %3d,%6d,%7d,\n", (
int)b->
nodes[i].max_count_label, (
int)b->
nodes[i].max_count,
390 fprintf(fp,
"left: %4d, right: %4d", (
int)b->
nodes[i].left, (
int)b->
nodes[i].right);
401 std::stringstream msg;
402 msg <<
"k = " << b.
k;
405 msg <<
"nodes = " << b.
nodes.
size() <<
" ";
406 uint32_t temp = (uint32_t)b.
nodes.
size();
417 msg <<
"progress = " << b.
progress <<
" ";
423 for (
size_t j = 0; j < b.
nodes.
size(); j++)
428 msg <<
" parent = " << n.
parent;
431 uint32_t temp = (uint32_t)n.
preds.
size();
433 msg <<
" preds = " << temp;
441 msg <<
" internal = " << n.
internal;
449 msg <<
" left = " << n.
left;
452 msg <<
" right = " << n.
right;
455 msg <<
" norm_Eh = " << n.
norm_Eh;
458 msg <<
" Eh = " << n.
Eh;
461 msg <<
" n = " << n.
n <<
"\n";
473 for (
size_t k = 0; k < n.
preds.
size(); k++)
477 msg <<
" Ehk = " << p.
Ehk;
480 msg <<
" norm_Ehk = " << p.
norm_Ehk;
483 msg <<
" nk = " << p.
nk;
486 msg <<
" label = " << p.
label;
498 auto data = scoped_calloc_or_throw<log_multi>();
500 new_options.add(
make_option(
"log_multi", data->k).keep().help(
"Use online tree for multiclass"))
501 .
add(
make_option(
"no_progress", data->progress).help(
"disable progressive validation"))
502 .
add(
make_option(
"swap_resistance", data->swap_resist).default_value(4).help(
"disable progressive validation"))
505 .help(
"higher = more resistance to swap, default=4"));
511 data->progress = !data->progress;
514 float loss_parameter = 0.5;
518 data->max_predictors = data->k - 1;
void predict(E &ec, size_t i=0)
void save_node_stats(log_multi &d)
bool operator<(node_pred v)
base_learner * make_base(learner< T, E > &base)
virtual void add_and_parse(const option_group_definition &group)=0
void train_node(log_multi &b, single_learner &base, example &ec, uint32_t &current, uint32_t &class_index, uint32_t)
v_array< node_pred > preds
single_learner * as_singleline(learner< T, E > *l)
MULTICLASS::label_t multi
bool operator==(node_pred v)
void push_back(const T &new_ele)
virtual bool was_supplied(const std::string &key)=0
base_learner * log_multi_setup(options_i &options, vw &all)
void update_min_count(log_multi &b, uint32_t node)
void display_tree_dfs(log_multi &b, const node &node, uint32_t depth)
uint32_t min_left_right(log_multi &b, const node &n)
int add(svm_params &params, svm_example *fec)
void save_load_tree(log_multi &b, io_buf &model_file, bool read, bool text)
typed_option< T > make_option(std::string name, T &location)
void init_tree(log_multi &d)
void learn(log_multi &b, single_learner &base, example &ec)
learner< T, E > & init_multiclass_learner(free_ptr< T > &dat, L *base, void(*learn)(T &, L &, E &), void(*predict)(T &, L &, E &), parser *p, size_t ws, prediction_type::prediction_type_t pred_type=prediction_type::multiclass)
LEARNER::base_learner * setup_base(options_i &options, vw &all)
size_t sum_count_dfs(log_multi &b, const node &node)
void learn(E &ec, size_t i=0)
bool children(log_multi &b, uint32_t &current, uint32_t &class_index, uint32_t label)
loss_function * getLossFunction(vw &all, std::string funcName, float function_parameter)
size_t bin_text_read_write_fixed(io_buf &io, char *data, size_t len, const char *read_message, bool read, std::stringstream &msg, bool text)
void verify_min_dfs(log_multi &b, const node &node)
uint32_t find_switch_node(log_multi &b)
uint32_t descend(node &n, float prediction)
size_t unique_add_sorted(const T &new_ele)
void predict(log_multi &b, single_learner &base, example &ec)
bool operator>(node_pred v)