Skip to content

Commit a009254

Browse files
committed
Add PLI variant into the AfdMetricCalc
1 parent 55527e9 commit a009254

3 files changed

Lines changed: 191 additions & 25 deletions

File tree

src/core/algorithms/fd/afd_metric/afd_metric_calculator.cpp

Lines changed: 164 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <chrono>
44
#include <cmath>
55
#include <iterator>
6+
#include <variant>
67

78
#include "core/config/descriptions.h"
89
#include "core/config/equal_nulls/option.h"
@@ -30,13 +31,14 @@ void AFDMetricCalculator::RegisterOptions() {
3031
RegisterOption(config::kLhsIndicesOpt(&lhs_indices_, get_schema_cols));
3132
RegisterOption(config::kRhsIndicesOpt(&rhs_indices_, get_schema_cols));
3233
RegisterOption(Option{&metric_, kMetric, kDAFDMetric});
34+
RegisterOption(config::kUsePliwsOpt(&use_pliws_));
3335
}
3436

3537
void AFDMetricCalculator::MakeExecuteOptsAvailable() {
3638
using namespace config::names;
3739

38-
MakeOptionsAvailable(
39-
{kMetric, config::kLhsIndicesOpt.GetName(), config::kRhsIndicesOpt.GetName()});
40+
MakeOptionsAvailable({kMetric, config::kLhsIndicesOpt.GetName(),
41+
config::kRhsIndicesOpt.GetName(), "use_pliws"});
4042
}
4143

4244
void AFDMetricCalculator::LoadDataInternal() {
@@ -56,25 +58,52 @@ unsigned long long AFDMetricCalculator::ExecuteInternal() {
5658

5759
void AFDMetricCalculator::CalculateMetric() {
5860
auto num_rows = relation_->GetNumRows();
59-
auto lhs_pli = relation_->CalculatePLIWS(lhs_indices_);
60-
auto rhs_pli = relation_->CalculatePLIWS(rhs_indices_);
61-
62-
switch (metric_) {
63-
case AFDMetric::g2:
64-
result_ = CalculateG2(lhs_pli.get(), rhs_pli.get(), num_rows);
65-
break;
66-
case AFDMetric::tau:
67-
result_ = CalculateTau(lhs_pli.get(), rhs_pli.get(),
68-
lhs_pli->Intersect(rhs_pli.get()).get());
69-
break;
70-
case AFDMetric::mu_plus:
71-
result_ = CalculateMuPlus(lhs_pli.get(), rhs_pli.get(),
72-
lhs_pli->Intersect(rhs_pli.get()).get());
73-
break;
74-
case AFDMetric::fi:
75-
result_ = CalculateFI(lhs_pli.get(), rhs_pli.get(), num_rows);
76-
break;
61+
62+
using PLIVariant = std::variant<std::shared_ptr<model::PLI const>,
63+
std::shared_ptr<model::PLIWithSingletons const>>;
64+
PLIVariant lhs_pli;
65+
PLIVariant rhs_pli;
66+
67+
if (use_pliws_) {
68+
lhs_pli = relation_->CalculatePLIWS(lhs_indices_);
69+
rhs_pli = relation_->CalculatePLIWS(rhs_indices_);
70+
} else {
71+
lhs_pli = relation_->CalculatePLI(lhs_indices_);
72+
rhs_pli = relation_->CalculatePLI(rhs_indices_);
7773
}
74+
75+
auto visitor = [this, num_rows](auto&& lhs_ptr, auto&& rhs_ptr) {
76+
using LhsType = std::decay_t<decltype(lhs_ptr)>;
77+
using RhsType = std::decay_t<decltype(rhs_ptr)>;
78+
79+
if constexpr (std::is_same_v<LhsType, RhsType>) {
80+
auto* lhs_raw = lhs_ptr.get();
81+
auto* rhs_raw = rhs_ptr.get();
82+
83+
switch (metric_) {
84+
case AFDMetric::g2:
85+
result_ = CalculateG2(lhs_raw, rhs_raw, num_rows);
86+
break;
87+
case AFDMetric::tau: {
88+
auto joint = lhs_raw->Intersect(rhs_raw);
89+
result_ = CalculateTau(lhs_raw, rhs_raw, joint.get());
90+
break;
91+
}
92+
case AFDMetric::mu_plus: {
93+
auto joint = lhs_raw->Intersect(rhs_raw);
94+
result_ = CalculateMuPlus(lhs_raw, rhs_raw, joint.get());
95+
break;
96+
}
97+
case AFDMetric::fi:
98+
result_ = CalculateFI(lhs_raw, rhs_raw, num_rows);
99+
break;
100+
}
101+
} else {
102+
std::terminate();
103+
}
104+
};
105+
106+
std::visit(visitor, lhs_pli, rhs_pli);
78107
}
79108

80109
long double AFDMetricCalculator::CalculateG2(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
@@ -95,7 +124,7 @@ long double AFDMetricCalculator::CalculateG2(model::PLI const* lhs_pli, model::P
95124
return num_error_rows / num_rows;
96125
}
97126

98-
long double AFDMetricCalculator::CalculatePdepSelf(model::PLIWithSingletons const* x_pli) {
127+
long double AFDMetricCalculator::CalculatePdepSelf(model::PLI const* x_pli) {
99128
size_t n = x_pli->GetRelationSize();
100129
config::ErrorType sum = 0;
101130
std::size_t cluster_rows_count = 0;
@@ -150,6 +179,45 @@ long double AFDMetricCalculator::CalculatePdepMeasure(model::PLIWithSingletons c
150179
return (sum / static_cast<config::ErrorType>(n));
151180
}
152181

182+
/// Pdep-style measure computed from two position list indexes.
///
/// The result is
///   ( sum over clusters c of xa_pli of |c|^2 / freq_x(first row of c)
///     + sum over rows r that xa_pli's probing table marks as singletons of 1 / freq_x(r) ) / n
/// where freq_x(r) is the size of the x_pli cluster containing row r (1 for a
/// row that x_pli's probing table marks as a singleton) and n is
/// `x_pli->GetRelationSize()`.
///
/// @param x_pli   index whose cluster sizes provide the denominators
/// @param xa_pli  index whose clusters and singleton rows are summed over;
///                presumably the intersection of x_pli with the RHS index —
///                confirm against callers
/// @return the accumulated sum divided by the relation size
long double AFDMetricCalculator::CalculatePdepMeasure(model::PositionListIndex const* x_pli,
                                                      model::PositionListIndex const* xa_pli) {
    // The indexes are only read here, so bind them by reference instead of
    // deep-copying every cluster.
    std::deque<Cluster> const& xa_index = xa_pli->GetIndex();
    std::deque<Cluster> const& x_index = x_pli->GetIndex();
    size_t const n = x_pli->GetRelationSize();

    // Value id -> number of rows carrying that value in x_pli. Ids are handed out
    // from 1 in index order, mirroring the original numbering.
    // NOTE(review): assumes the probing table numbers clusters the same way — confirm.
    std::unordered_map<int, size_t> x_frequencies;
    x_frequencies.reserve(x_index.size() + 1);

    int x_value_id = 1;
    for (Cluster const& x_cluster : x_index) {
        x_frequencies[x_value_id++] = x_cluster.size();
    }
    // Rows outside every cluster share the singleton id and occur exactly once.
    x_frequencies[model::PositionListIndex::kSingletonValueId] = 1;

    auto const x_prob = x_pli->CalculateAndGetProbingTable();

    // Row indices are taken as size_t so the loop index below is not narrowed to int.
    auto const get_x_freq_by_tuple_ind = [&x_prob, &x_frequencies](size_t tuple_ind) {
        int const value_id = x_prob->at(tuple_ind);
        return static_cast<config::ErrorType>(x_frequencies[value_id]);
    };

    config::ErrorType sum = 0;

    // Every row of an xa cluster has the same x frequency, so the first row is
    // used as the representative.
    for (Cluster const& xa_cluster : xa_index) {
        config::ErrorType const num = xa_cluster.size() * xa_cluster.size();
        config::ErrorType const denum = get_x_freq_by_tuple_ind(xa_cluster.front());
        sum += num / denum;
    }

    // Singleton rows of xa_pli are not present in its index; account for each of
    // them as a cluster of size one.
    auto const xa_prob = xa_pli->CalculateAndGetProbingTable();
    for (size_t i = 0; i < xa_prob->size(); ++i) {
        if (xa_prob->at(i) == model::PositionListIndex::kSingletonValueId) {
            sum += 1 / get_x_freq_by_tuple_ind(i);
        }
    }

    return sum / static_cast<config::ErrorType>(n);
}
220+
153221
long double AFDMetricCalculator::CalculateTau(model::PLIWS const* lhs_pli,
154222
model::PLIWS const* rhs_pli,
155223
model::PLIWS const* joint_pli) {
@@ -161,6 +229,16 @@ long double AFDMetricCalculator::CalculateTau(model::PLIWS const* lhs_pli,
161229
return (p2 - p1) / (1 - p1);
162230
}
163231

232+
/// Tau metric for singleton-less position list indexes.
///
/// With p1 = CalculatePdepSelf(rhs_pli) and
/// p2 = CalculatePdepMeasure(lhs_pli, joint_pli) the result is
/// (p2 - p1) / (1 - p1); when p1 equals 1 the function returns 1 without
/// evaluating p2.
///
/// @param lhs_pli    index passed as the first argument of the pdep measure
/// @param rhs_pli    index whose self-pdep is the baseline
/// @param joint_pli  index passed as the second argument of the pdep measure
long double AFDMetricCalculator::CalculateTau(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
                                              model::PLI const* joint_pli) {
    auto const rhs_self_pdep = CalculatePdepSelf(rhs_pli);
    if (rhs_self_pdep == 1) {
        // The denominator below would be zero; report the maximum directly.
        return 1;
    }

    auto const joint_pdep = CalculatePdepMeasure(lhs_pli, joint_pli);

    auto const gain = joint_pdep - rhs_self_pdep;
    auto const headroom = 1 - rhs_self_pdep;
    return gain / headroom;
}
241+
164242
long double AFDMetricCalculator::CalculateMuPlus(model::PLIWS const* lhs_pli,
165243
model::PLIWS const* rhs_pli,
166244
model::PLIWS const* joint_pli) {
@@ -186,6 +264,33 @@ long double AFDMetricCalculator::CalculateMuPlus(model::PLIWS const* lhs_pli,
186264
return std::max(mu, 0.L);
187265
}
188266

267+
/// Mu+ metric for singleton-less position list indexes.
///
/// With pdep_y = CalculatePdepSelf(a_pli), pdep_xy =
/// CalculatePdepMeasure(x_pli, xa_pli), n the relation size of x_pli and k the
/// number of x_pli clusters plus the number of rows not covered by any cluster,
/// the result is
///   max(0, 1 - (1 - pdep_xy) / (1 - pdep_y) * (n - 1) / (n - k)).
/// Returns 1 early when pdep_y equals 1 or when k equals n (both would make a
/// denominator zero).
///
/// @param x_pli   index over the left-hand side
/// @param a_pli   index over the right-hand side
/// @param xa_pli  index passed as the second argument of the pdep measure
long double AFDMetricCalculator::CalculateMuPlus(model::PositionListIndex const* x_pli,
                                                 model::PositionListIndex const* a_pli,
                                                 model::PositionListIndex const* xa_pli) {
    config::ErrorType const pdep_y = CalculatePdepSelf(a_pli);
    if (pdep_y == 1) return 1;

    size_t const n = x_pli->GetRelationSize();
    std::deque<Cluster> const& x_index = x_pli->GetIndex();

    std::size_t cluster_rows_count = 0;
    for (Cluster const& x_cluster : x_index) {
        cluster_rows_count += x_cluster.size();
    }

    // Rows that belong to no cluster each count as their own distinct value.
    std::size_t const unique_rows = n - cluster_rows_count;
    size_t const k = x_index.size() + unique_rows;

    // Cheap exit first: the pdep measure below walks both indexes and its result
    // would be discarded in this case.
    if (k == n) return 1;

    config::ErrorType const pdep_xy = CalculatePdepMeasure(x_pli, xa_pli);

    config::ErrorType const mu = 1 - (1 - pdep_xy) / (1 - pdep_y) * (n - 1) / (n - k);
    return std::max(0., mu);
}
293+
189294
long double AFDMetricCalculator::CalculateFI(model::PLIWS const* lhs_pli,
190295
model::PLIWS const* rhs_pli, size_t num_rows) {
191296
if (num_rows <= 0) throw std::invalid_argument("received unpositive number of rows");
@@ -219,4 +324,42 @@ long double AFDMetricCalculator::CalculateFI(model::PLIWS const* lhs_pli,
219324
return mutual_information / entropy;
220325
}
221326

327+
/// FI metric for singleton-less position list indexes.
///
/// Computes (H - Hc) / H where H is `rhs_pli->GetEntropy()` and Hc is the
/// conditional term accumulated below from the pairwise intersections of the
/// LHS and RHS clusters, divided by `num_rows`. Rows of an LHS cluster that fall
/// into no RHS cluster are each treated as an intersection of size one.
///
/// @param lhs_pli   index over the left-hand side
/// @param rhs_pli   index over the right-hand side
/// @param num_rows  number of rows in the relation; must be non-zero
/// @return 0 when rhs_pli reports fewer than two clusters, otherwise the ratio above
/// @throws std::invalid_argument if num_rows is zero
///
/// NOTE(review): with a singleton-less index "fewer than two clusters" does not
/// obviously mean "RHS is constant" (all-distinct values also yield no
/// clusters) — confirm that returning 0 is intended in that case.
/// NOTE(review): a zero entropy with two or more clusters would divide by zero —
/// presumably impossible, verify against GetEntropy.
long double AFDMetricCalculator::CalculateFI(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
                                             size_t num_rows) {
    // num_rows is unsigned, so "not positive" can only mean zero.
    if (num_rows == 0) throw std::invalid_argument("received unpositive number of rows");

    if (rhs_pli->GetNumCluster() < 2) {
        return 0.L;
    }

    auto const entropy = rhs_pli->GetEntropy();

    // std::set_intersection needs sorted ranges, so work on sorted copies.
    std::deque<Cluster> rhs_clusters{rhs_pli->GetIndex()};
    for (auto& y : rhs_clusters) {
        std::sort(y.begin(), y.end());
    }
    std::deque<Cluster> lhs_clusters{lhs_pli->GetIndex()};

    auto conditional_entropy = 0.L;
    // Scratch buffer reused across iterations to avoid reallocating per pair.
    model::PositionListIndex::Cluster xy;

    for (auto& x : lhs_clusters) {
        std::sort(x.begin(), x.end());
        auto const log_x = std::log(x.size());

        // Number of rows of x already attributed to some RHS cluster.
        size_t xy_calc_cnt = 0;
        for (auto const& y : rhs_clusters) {
            xy.clear();
            std::set_intersection(x.begin(), x.end(), y.begin(), y.end(), std::back_inserter(xy));
            if (xy.empty()) continue;

            auto const size = static_cast<long double>(xy.size());
            xy_calc_cnt += xy.size();
            conditional_entropy -= size * (std::log(size) - log_x);
        }

        // Remaining rows of x: each forms an intersection of size one, whose
        // contribution is 1 * (log(1) - log_x) = -log_x.
        conditional_entropy -= static_cast<long double>(x.size() - xy_calc_cnt) * (-log_x);
    }

    conditional_entropy /= num_rows;
    auto const mutual_information = entropy - conditional_entropy;
    return mutual_information / entropy;
}
364+
222365
} // namespace algos::afd_metric_calculator

src/core/algorithms/fd/afd_metric/afd_metric_calculator.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "core/config/equal_nulls/type.h"
1010
#include "core/config/indices/type.h"
1111
#include "core/config/tabular_data/input_table_type.h"
12+
#include "core/config/use_pliws/option.h"
1213
#include "core/model/table/column_layout_relation_data.h"
1314
#include "core/model/table/position_list_index.h"
1415
#include "core/model/table/position_list_index_with_singletons.h"
@@ -20,6 +21,7 @@ class AFDMetricCalculator : public Algorithm {
2021
config::InputTable input_table_;
2122

2223
AFDMetric metric_ = AFDMetric::_values()[0];
24+
bool use_pliws_ = false;
2325
config::IndicesType lhs_indices_;
2426
config::IndicesType rhs_indices_;
2527

@@ -45,23 +47,35 @@ class AFDMetricCalculator : public Algorithm {
4547
size_t num_rows, std::deque<model::PositionListIndex::Cluster>&& lhs_clusters,
4648
std::deque<model::PositionListIndex::Cluster>&& rhs_clusters);
4749

48-
static long double CalculatePdepSelf(model::PLIWithSingletons const* x_pli);
50+
static long double CalculatePdepSelf(model::PLI const* x_pli);
4951

5052
static long double CalculatePdepMeasure(model::PLIWithSingletons const* x_pli,
5153
model::PLIWithSingletons const* xa_pli);
5254

55+
static long double CalculatePdepMeasure(model::PositionListIndex const* x_pli,
56+
model::PositionListIndex const* xa_pli);
57+
5358
static long double CalculateG2(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
5459
size_t num_rows);
5560

5661
static long double CalculateTau(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
5762
model::PLIWS const* joint_pli);
5863

64+
static long double CalculateTau(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
65+
model::PLI const* joint_pli);
66+
5967
static long double CalculateMuPlus(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
6068
model::PLIWS const* joint_pli);
6169

70+
static long double CalculateMuPlus(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
71+
model::PLI const* joint_pli);
72+
6273
static long double CalculateFI(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
6374
size_t num_rows);
6475

76+
static long double CalculateFI(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
77+
size_t num_rows);
78+
6579
long double GetResult() const {
6680
return result_;
6781
}

src/tests/unit/test_afd_metric_calculator.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ struct AFDMetricCalculatorParams {
1515
long double const expected = 0.L;
1616

1717
AFDMetricCalculatorParams(config::IndicesType lhs_indices, config::IndicesType rhs_indices,
18-
AFDMetric metric, long double expected,
18+
AFDMetric metric, long double expected, bool use_pliws = false,
1919
CSVConfig const& csv_config = kTestFD)
2020
: params({{kCsvConfig, csv_config},
2121
{kLhsIndices, std::move(lhs_indices)},
2222
{kRhsIndices, std::move(rhs_indices)},
2323
{kEqualNulls, true},
24-
{kMetric, metric}}),
24+
{kMetric, metric},
25+
{"use_pliws", use_pliws}}),
2526
expected(expected) {}
2627
};
2728

@@ -41,13 +42,21 @@ INSTANTIATE_TEST_SUITE_P(
4142
AFDMetricCalculatorTestSuite, TestAFDMetrics,
4243
::testing::Values(
4344
AFDMetricCalculatorParams({4}, {3}, AFDMetric::tau, 78.L/90),
45+
AFDMetricCalculatorParams({4}, {3}, AFDMetric::tau, 78.L/90, true),
4446
AFDMetricCalculatorParams({4}, {3}, AFDMetric::g2, 1.L/6),
47+
AFDMetricCalculatorParams({4}, {3}, AFDMetric::g2, 1.L/6, true),
4548
AFDMetricCalculatorParams({4}, {3}, AFDMetric::fi, 1 - std::log(4) / std::log(746496)),
49+
AFDMetricCalculatorParams({4}, {3}, AFDMetric::fi, 1 - std::log(4) / std::log(746496), true),
4650
AFDMetricCalculatorParams({4}, {3}, AFDMetric::mu_plus, 498.L/630),
51+
AFDMetricCalculatorParams({4}, {3}, AFDMetric::mu_plus, 498.L/630, true),
4752
AFDMetricCalculatorParams({3}, {4}, AFDMetric::tau, 54.L/114),
53+
AFDMetricCalculatorParams({3}, {4}, AFDMetric::tau, 54.L/114, true),
4854
AFDMetricCalculatorParams({3}, {4}, AFDMetric::g2, 5.L/6),
55+
AFDMetricCalculatorParams({3}, {4}, AFDMetric::g2, 5.L/6, true),
4956
AFDMetricCalculatorParams({3}, {4}, AFDMetric::fi, std::log(432) / std::log(13824)),
50-
AFDMetricCalculatorParams({3}, {4}, AFDMetric::mu_plus, 252.L/912)
57+
AFDMetricCalculatorParams({3}, {4}, AFDMetric::fi, std::log(432) / std::log(13824), true),
58+
AFDMetricCalculatorParams({3}, {4}, AFDMetric::mu_plus, 252.L/912),
59+
AFDMetricCalculatorParams({3}, {4}, AFDMetric::mu_plus, 252.L/912, true)
5160
));
5261
// clang-format on
5362

0 commit comments

Comments
 (0)