33#include < chrono>
44#include < cmath>
55#include < iterator>
6+ #include < variant>
67
78#include " core/config/descriptions.h"
89#include " core/config/equal_nulls/option.h"
@@ -30,13 +31,14 @@ void AFDMetricCalculator::RegisterOptions() {
3031 RegisterOption (config::kLhsIndicesOpt (&lhs_indices_, get_schema_cols));
3132 RegisterOption (config::kRhsIndicesOpt (&rhs_indices_, get_schema_cols));
3233 RegisterOption (Option{&metric_, kMetric , kDAFDMetric });
34+ RegisterOption (config::kUsePliwsOpt (&use_pliws_));
3335}
3436
3537void AFDMetricCalculator::MakeExecuteOptsAvailable () {
3638 using namespace config ::names;
3739
38- MakeOptionsAvailable (
39- { kMetric , config:: kLhsIndicesOpt . GetName (), config::kRhsIndicesOpt .GetName ()});
40+ MakeOptionsAvailable ({ kMetric , config:: kLhsIndicesOpt . GetName (),
41+ config::kRhsIndicesOpt .GetName (), " use_pliws " });
4042}
4143
4244void AFDMetricCalculator::LoadDataInternal () {
@@ -56,25 +58,52 @@ unsigned long long AFDMetricCalculator::ExecuteInternal() {
5658
5759void AFDMetricCalculator::CalculateMetric () {
5860 auto num_rows = relation_->GetNumRows ();
59- auto lhs_pli = relation_->CalculatePLIWS (lhs_indices_);
60- auto rhs_pli = relation_->CalculatePLIWS (rhs_indices_);
61-
62- switch (metric_) {
63- case AFDMetric::g2:
64- result_ = CalculateG2 (lhs_pli.get (), rhs_pli.get (), num_rows);
65- break ;
66- case AFDMetric::tau:
67- result_ = CalculateTau (lhs_pli.get (), rhs_pli.get (),
68- lhs_pli->Intersect (rhs_pli.get ()).get ());
69- break ;
70- case AFDMetric::mu_plus:
71- result_ = CalculateMuPlus (lhs_pli.get (), rhs_pli.get (),
72- lhs_pli->Intersect (rhs_pli.get ()).get ());
73- break ;
74- case AFDMetric::fi:
75- result_ = CalculateFI (lhs_pli.get (), rhs_pli.get (), num_rows);
76- break ;
61+
62+ using PLIVariant = std::variant<std::shared_ptr<model::PLI const >,
63+ std::shared_ptr<model::PLIWithSingletons const >>;
64+ PLIVariant lhs_pli;
65+ PLIVariant rhs_pli;
66+
67+ if (use_pliws_) {
68+ lhs_pli = relation_->CalculatePLIWS (lhs_indices_);
69+ rhs_pli = relation_->CalculatePLIWS (rhs_indices_);
70+ } else {
71+ lhs_pli = relation_->CalculatePLI (lhs_indices_);
72+ rhs_pli = relation_->CalculatePLI (rhs_indices_);
7773 }
74+
75+ auto visitor = [this , num_rows](auto && lhs_ptr, auto && rhs_ptr) {
76+ using LhsType = std::decay_t <decltype (lhs_ptr)>;
77+ using RhsType = std::decay_t <decltype (rhs_ptr)>;
78+
79+ if constexpr (std::is_same_v<LhsType, RhsType>) {
80+ auto * lhs_raw = lhs_ptr.get ();
81+ auto * rhs_raw = rhs_ptr.get ();
82+
83+ switch (metric_) {
84+ case AFDMetric::g2:
85+ result_ = CalculateG2 (lhs_raw, rhs_raw, num_rows);
86+ break ;
87+ case AFDMetric::tau: {
88+ auto joint = lhs_raw->Intersect (rhs_raw);
89+ result_ = CalculateTau (lhs_raw, rhs_raw, joint.get ());
90+ break ;
91+ }
92+ case AFDMetric::mu_plus: {
93+ auto joint = lhs_raw->Intersect (rhs_raw);
94+ result_ = CalculateMuPlus (lhs_raw, rhs_raw, joint.get ());
95+ break ;
96+ }
97+ case AFDMetric::fi:
98+ result_ = CalculateFI (lhs_raw, rhs_raw, num_rows);
99+ break ;
100+ }
101+ } else {
102+ std::terminate ();
103+ }
104+ };
105+
106+ std::visit (visitor, lhs_pli, rhs_pli);
78107}
79108
80109long double AFDMetricCalculator::CalculateG2 (model::PLI const * lhs_pli, model::PLI const * rhs_pli,
@@ -95,7 +124,7 @@ long double AFDMetricCalculator::CalculateG2(model::PLI const* lhs_pli, model::P
95124 return num_error_rows / num_rows;
96125}
97126
98- long double AFDMetricCalculator::CalculatePdepSelf (model::PLIWithSingletons const * x_pli) {
127+ long double AFDMetricCalculator::CalculatePdepSelf (model::PLI const * x_pli) {
99128 size_t n = x_pli->GetRelationSize ();
100129 config::ErrorType sum = 0 ;
101130 std::size_t cluster_rows_count = 0 ;
@@ -150,6 +179,45 @@ long double AFDMetricCalculator::CalculatePdepMeasure(model::PLIWithSingletons c
150179 return (sum / static_cast <config::ErrorType>(n));
151180}
152181
182+ long double AFDMetricCalculator::CalculatePdepMeasure (model::PositionListIndex const * x_pli,
183+ model::PositionListIndex const * xa_pli) {
184+ std::deque<Cluster> xa_index = xa_pli->GetIndex ();
185+ std::deque<Cluster> x_index = x_pli->GetIndex ();
186+ size_t n = x_pli->GetRelationSize ();
187+
188+ config::ErrorType sum = 0 ;
189+
190+ std::unordered_map<int , size_t > x_frequencies;
191+
192+ int x_value_id = 1 ;
193+ for (Cluster const & x_cluster : x_index) {
194+ x_frequencies[x_value_id++] = x_cluster.size ();
195+ }
196+
197+ x_frequencies[model::PositionListIndex::kSingletonValueId ] = 1 ;
198+
199+ auto x_prob = x_pli->CalculateAndGetProbingTable ();
200+
201+ auto get_x_freq_by_tuple_ind{[&x_prob, &x_frequencies](int tuple_ind) {
202+ int value_id = x_prob->at (tuple_ind);
203+ return static_cast <config::ErrorType>(x_frequencies[value_id]);
204+ }};
205+
206+ for (Cluster const & xa_cluster : xa_index) {
207+ config::ErrorType num = xa_cluster.size () * xa_cluster.size ();
208+ config::ErrorType denum = get_x_freq_by_tuple_ind (xa_cluster.front ());
209+ sum += num / denum;
210+ }
211+
212+ auto xa_prob = xa_pli->CalculateAndGetProbingTable ();
213+ for (size_t i = 0 ; i < xa_prob->size (); i++) {
214+ if (xa_prob->at (i) == 0 ) {
215+ sum += 1 / get_x_freq_by_tuple_ind (i);
216+ }
217+ }
218+ return (sum / static_cast <config::ErrorType>(n));
219+ }
220+
153221long double AFDMetricCalculator::CalculateTau (model::PLIWS const * lhs_pli,
154222 model::PLIWS const * rhs_pli,
155223 model::PLIWS const * joint_pli) {
@@ -161,6 +229,16 @@ long double AFDMetricCalculator::CalculateTau(model::PLIWS const* lhs_pli,
161229 return (p2 - p1) / (1 - p1);
162230}
163231
232+ long double AFDMetricCalculator::CalculateTau (model::PLI const * lhs_pli, model::PLI const * rhs_pli,
233+ model::PLI const * joint_pli) {
234+ auto p1 = CalculatePdepSelf (rhs_pli);
235+ if (p1 == 1 ) return 1 ;
236+
237+ auto p2 = CalculatePdepMeasure (lhs_pli, joint_pli);
238+
239+ return (p2 - p1) / (1 - p1);
240+ }
241+
164242long double AFDMetricCalculator::CalculateMuPlus (model::PLIWS const * lhs_pli,
165243 model::PLIWS const * rhs_pli,
166244 model::PLIWS const * joint_pli) {
@@ -186,6 +264,33 @@ long double AFDMetricCalculator::CalculateMuPlus(model::PLIWS const* lhs_pli,
186264 return std::max (mu, 0 .L );
187265}
188266
267+ long double AFDMetricCalculator::CalculateMuPlus (model::PositionListIndex const * x_pli,
268+ model::PositionListIndex const * a_pli,
269+ model::PositionListIndex const * xa_pli) {
270+ config::ErrorType pdep_y = CalculatePdepSelf (a_pli);
271+ if (pdep_y == 1 ) return 1 ;
272+
273+ config::ErrorType pdep_xy = CalculatePdepMeasure (x_pli, xa_pli);
274+
275+ size_t n = x_pli->GetRelationSize ();
276+ std::size_t cluster_rows_count = 0 ;
277+ std::deque<Cluster> const & x_index = x_pli->GetIndex ();
278+ size_t k = x_index.size ();
279+
280+ for (Cluster const & x_cluster : x_index) {
281+ cluster_rows_count += x_cluster.size ();
282+ }
283+
284+ std::size_t unique_rows = x_pli->GetRelationSize () - cluster_rows_count;
285+ k += unique_rows;
286+
287+ if (k == n) return 1 ;
288+
289+ config::ErrorType mu = 1 - (1 - pdep_xy) / (1 - pdep_y) * (n - 1 ) / (n - k);
290+ config::ErrorType mu_plus = std::max (0 ., mu);
291+ return mu_plus;
292+ }
293+
189294long double AFDMetricCalculator::CalculateFI (model::PLIWS const * lhs_pli,
190295 model::PLIWS const * rhs_pli, size_t num_rows) {
191296 if (num_rows <= 0 ) throw std::invalid_argument (" received unpositive number of rows" );
@@ -219,4 +324,42 @@ long double AFDMetricCalculator::CalculateFI(model::PLIWS const* lhs_pli,
219324 return mutual_information / entropy;
220325}
221326
327+ long double AFDMetricCalculator::CalculateFI (model::PLI const * lhs_pli, model::PLI const * rhs_pli,
328+ size_t num_rows) {
329+ if (num_rows <= 0 ) throw std::invalid_argument (" received unpositive number of rows" );
330+
331+ if (rhs_pli->GetNumCluster () < 2 ) {
332+ return 0 .L ;
333+ }
334+
335+ auto entropy = rhs_pli->GetEntropy ();
336+
337+ std::deque<Cluster> rhs_clusters{rhs_pli->GetIndex ()};
338+ for (auto & y : rhs_clusters) {
339+ std::sort (y.begin (), y.end ());
340+ }
341+
342+ auto conditional_entropy = 0 .L ;
343+ for (auto & x : std::deque<Cluster>{lhs_pli->GetIndex ()}) {
344+ std::sort (x.begin (), x.end ());
345+ auto log_x = std::log (x.size ());
346+
347+ size_t xy_calc_cnt = 0 ;
348+ for (auto const & y : rhs_clusters) {
349+ model::PositionListIndex::Cluster xy;
350+ std::set_intersection (x.begin (), x.end (), y.begin (), y.end (), std::back_inserter (xy));
351+
352+ auto size = (long double )xy.size ();
353+ if (size == 0 .L ) continue ;
354+ xy_calc_cnt += size;
355+ conditional_entropy -= size * (std::log (size) - log_x);
356+ }
357+
358+ conditional_entropy -= (long double )(x.size () - xy_calc_cnt) * (-log_x);
359+ }
360+ conditional_entropy /= num_rows;
361+ auto mutual_information = entropy - conditional_entropy;
362+ return mutual_information / entropy;
363+ }
364+
222365} // namespace algos::afd_metric_calculator
0 commit comments