Skip to content

Commit b6203e3

Browse files
committed
Fix comments
1 parent b9d7ed5 commit b6203e3

10 files changed

Lines changed: 142 additions & 119 deletions

File tree

src/core/algorithms/fd/afd_metric/afd_metric_calculator.cpp

Lines changed: 36 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ void AFDMetricCalculator::MakeExecuteOptsAvailable() {
3838
using namespace config::names;
3939

4040
MakeOptionsAvailable({kMetric, config::kLhsIndicesOpt.GetName(),
41-
config::kRhsIndicesOpt.GetName(), "use_pliws"});
41+
config::kRhsIndicesOpt.GetName(), config::kUsePliwsOpt.GetName()});
4242
}
4343

4444
void AFDMetricCalculator::LoadDataInternal() {
@@ -56,54 +56,47 @@ unsigned long long AFDMetricCalculator::ExecuteInternal() {
5656
return elapsed_milliseconds;
5757
}
5858

59+
namespace {
60+
61+
template <typename PLIT>
62+
auto CalculateMetric(PLIT&& lhs_ptr, PLIT&& rhs_ptr, AFDMetric metric, size_t num_rows) {
63+
auto* lhs_raw = lhs_ptr.get();
64+
auto* rhs_raw = rhs_ptr.get();
65+
66+
switch (metric) {
67+
case AFDMetric::g2:
68+
return AFDMetricCalculator::CalculateG2(lhs_raw, rhs_raw, num_rows);
69+
case AFDMetric::tau: {
70+
auto joint = lhs_raw->Intersect(rhs_raw);
71+
return AFDMetricCalculator::CalculateTau(lhs_raw, rhs_raw, joint.get());
72+
}
73+
case AFDMetric::mu_plus: {
74+
auto joint = lhs_raw->Intersect(rhs_raw);
75+
return AFDMetricCalculator::CalculateMuPlus(lhs_raw, rhs_raw, joint.get());
76+
}
77+
case AFDMetric::fi:
78+
return AFDMetricCalculator::CalculateFI(lhs_raw, rhs_raw, num_rows);
79+
}
80+
81+
assert(false);
82+
__builtin_unreachable();
83+
}
84+
} // namespace
85+
5986
void AFDMetricCalculator::CalculateMetric() {
6087
auto num_rows = relation_->GetNumRows();
6188

62-
using PLIVariant = std::variant<std::shared_ptr<model::PLI const>,
63-
std::shared_ptr<model::PLIWithSingletons const>>;
64-
PLIVariant lhs_pli;
65-
PLIVariant rhs_pli;
66-
6789
if (use_pliws_) {
68-
lhs_pli = relation_->CalculatePLIWS(lhs_indices_);
69-
rhs_pli = relation_->CalculatePLIWS(rhs_indices_);
90+
auto lhs_pli = relation_->CalculatePLIWS(lhs_indices_);
91+
auto rhs_pli = relation_->CalculatePLIWS(rhs_indices_);
92+
result_ = ::algos::afd_metric_calculator::CalculateMetric(lhs_pli, rhs_pli, metric_,
93+
num_rows);
7094
} else {
71-
lhs_pli = relation_->CalculatePLI(lhs_indices_);
72-
rhs_pli = relation_->CalculatePLI(rhs_indices_);
95+
auto lhs_pli = relation_->CalculatePLI(lhs_indices_);
96+
auto rhs_pli = relation_->CalculatePLI(rhs_indices_);
97+
result_ = ::algos::afd_metric_calculator::CalculateMetric(lhs_pli, rhs_pli, metric_,
98+
num_rows);
7399
}
74-
75-
auto visitor = [this, num_rows](auto&& lhs_ptr, auto&& rhs_ptr) {
76-
using LhsType = std::decay_t<decltype(lhs_ptr)>;
77-
using RhsType = std::decay_t<decltype(rhs_ptr)>;
78-
79-
if constexpr (std::is_same_v<LhsType, RhsType>) {
80-
auto* lhs_raw = lhs_ptr.get();
81-
auto* rhs_raw = rhs_ptr.get();
82-
83-
switch (metric_) {
84-
case AFDMetric::g2:
85-
result_ = CalculateG2(lhs_raw, rhs_raw, num_rows);
86-
break;
87-
case AFDMetric::tau: {
88-
auto joint = lhs_raw->Intersect(rhs_raw);
89-
result_ = CalculateTau(lhs_raw, rhs_raw, joint.get());
90-
break;
91-
}
92-
case AFDMetric::mu_plus: {
93-
auto joint = lhs_raw->Intersect(rhs_raw);
94-
result_ = CalculateMuPlus(lhs_raw, rhs_raw, joint.get());
95-
break;
96-
}
97-
case AFDMetric::fi:
98-
result_ = CalculateFI(lhs_raw, rhs_raw, num_rows);
99-
break;
100-
}
101-
} else {
102-
std::terminate();
103-
}
104-
};
105-
106-
std::visit(visitor, lhs_pli, rhs_pli);
107100
}
108101

109102
long double AFDMetricCalculator::CalculateG2(model::PLI const* lhs_pli, model::PLI const* rhs_pli,

src/core/algorithms/fd/tane/pfdtane.cpp

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -27,21 +27,16 @@ config::ErrorType PFDTane::CalculateZeroAryFdError(ColumnData const* rhs) {
2727
return CalculateZeroAryPFDError(rhs);
2828
}
2929

30-
config::ErrorType PFDTane::CalculateFdError(tane::PLIVariantPtr lhs_pli,
31-
tane::PLIVariantPtr rhs_pli,
32-
tane::PLIVariantPtr joint_pli) {
33-
auto visitor = [this](auto&& lhs_ptr, auto&& rhs_ptr, auto&& joint_ptr) -> config::ErrorType {
34-
using LhsType = std::decay_t<decltype(lhs_ptr)>;
35-
using RhsType = std::decay_t<decltype(rhs_ptr)>;
36-
using JointType = std::decay_t<decltype(joint_ptr)>;
30+
config::ErrorType PFDTane::CalculateFdError(model::PLIWS const* lhs_pli,
31+
[[maybe_unused]] model::PLIWS const* rhs_pli,
32+
model::PLIWS const* joint_pli) {
33+
return CalculatePFDError(lhs_pli, joint_pli, pfd_error_measure_);
34+
}
3735

38-
if constexpr (std::is_same_v<LhsType, RhsType> && std::is_same_v<RhsType, JointType>) {
39-
return CalculatePFDError(lhs_ptr, joint_ptr, pfd_error_measure_);
40-
} else {
41-
std::terminate();
42-
}
43-
};
44-
return std::visit(visitor, lhs_pli, rhs_pli, joint_pli);
36+
config::ErrorType PFDTane::CalculateFdError(model::PLI const* lhs_pli,
37+
[[maybe_unused]] model::PLI const* rhs_pli,
38+
model::PLI const* joint_pli) {
39+
return CalculatePFDError(lhs_pli, joint_pli, pfd_error_measure_);
4540
}
4641

4742
config::ErrorType PFDTane::CalculateZeroAryPFDError(ColumnData const* rhs) {

src/core/algorithms/fd/tane/pfdtane.h

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,12 @@ class PFDTane : public tane::TaneCommon {
1515
void RegisterOptions();
1616
void MakeExecuteOptsAvailableFDInternal() final;
1717
config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) override;
18-
config::ErrorType CalculateFdError(tane::PLIVariantPtr lhs_pli, tane::PLIVariantPtr rhs_pli,
19-
tane::PLIVariantPtr joint_pli) override;
18+
config::ErrorType CalculateFdError(model::PLI const* lhs_pli,
19+
[[maybe_unused]] model::PLI const* rhs_pli,
20+
model::PLI const* joint_pli) override;
21+
config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli,
22+
[[maybe_unused]] model::PLIWS const* rhs_pli,
23+
model::PLIWS const* joint_pli) override;
2024

2125
public:
2226
PFDTane();

src/core/algorithms/fd/tane/tane.cpp

Lines changed: 36 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -24,36 +24,42 @@ config::ErrorType Tane::CalculateZeroAryFdError(ColumnData const* rhs) {
2424
return 1;
2525
}
2626

27-
config::ErrorType Tane::CalculateFdError(tane::PLIVariantPtr lhs_pli, tane::PLIVariantPtr rhs_pli,
28-
tane::PLIVariantPtr joint_pli) {
29-
auto visitor = [this](auto&& lhs_pli, auto&& rhs_pli, auto&& joint_pli) -> config::ErrorType {
30-
using LhsType = std::decay_t<decltype(lhs_pli)>;
31-
using RhsType = std::decay_t<decltype(rhs_pli)>;
32-
using JointType = std::decay_t<decltype(joint_pli)>;
33-
34-
if constexpr (std::is_same_v<LhsType, RhsType> and std::is_same_v<RhsType, JointType>) {
35-
switch (afd_error_measure_) {
36-
case +AfdErrorMeasure::pdep:
37-
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculatePdepMeasure(
38-
lhs_pli, joint_pli);
39-
case +AfdErrorMeasure::tau:
40-
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateTau(
41-
lhs_pli, rhs_pli, joint_pli);
42-
case +AfdErrorMeasure::mu_plus:
43-
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateMuPlus(
44-
lhs_pli, rhs_pli, joint_pli);
45-
case +AfdErrorMeasure::rho:
46-
return 1 - CalculateRhoMeasure(lhs_pli, joint_pli);
47-
default:
48-
return CalculateG1Error(lhs_pli, joint_pli,
49-
relation_.get()->GetNumTuplePairs());
50-
}
51-
} else {
52-
std::terminate();
53-
}
54-
};
55-
56-
return std::visit(visitor, lhs_pli, rhs_pli, joint_pli);
27+
config::ErrorType Tane::CalculateFdError(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
28+
model::PLI const* joint_pli) {
29+
switch (afd_error_measure_) {
30+
case +AfdErrorMeasure::pdep:
31+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculatePdepMeasure(lhs_pli,
32+
joint_pli);
33+
case +AfdErrorMeasure::tau:
34+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateTau(lhs_pli, rhs_pli,
35+
joint_pli);
36+
case +AfdErrorMeasure::mu_plus:
37+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateMuPlus(lhs_pli, rhs_pli,
38+
joint_pli);
39+
case +AfdErrorMeasure::rho:
40+
return 1 - CalculateRhoMeasure(lhs_pli, joint_pli);
41+
default:
42+
return CalculateG1Error(lhs_pli, joint_pli, relation_.get()->GetNumTuplePairs());
43+
}
44+
}
45+
46+
config::ErrorType Tane::CalculateFdError(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
47+
model::PLIWS const* joint_pli) {
48+
switch (afd_error_measure_) {
49+
case +AfdErrorMeasure::pdep:
50+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculatePdepMeasure(lhs_pli,
51+
joint_pli);
52+
case +AfdErrorMeasure::tau:
53+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateTau(lhs_pli, rhs_pli,
54+
joint_pli);
55+
case +AfdErrorMeasure::mu_plus:
56+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateMuPlus(lhs_pli, rhs_pli,
57+
joint_pli);
58+
case +AfdErrorMeasure::rho:
59+
return 1 - CalculateRhoMeasure(lhs_pli, joint_pli);
60+
default:
61+
return CalculateG1Error(lhs_pli, joint_pli, relation_.get()->GetNumTuplePairs());
62+
}
5763
}
5864

5965
} // namespace algos

src/core/algorithms/fd/tane/tane.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ class Tane : public tane::TaneCommon {
1313
AfdErrorMeasure afd_error_measure_ = +AfdErrorMeasure::g1;
1414
void MakeExecuteOptsAvailableFDInternal() override final;
1515
config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) override;
16-
config::ErrorType CalculateFdError(tane::PLIVariantPtr lhs_pli, tane::PLIVariantPtr rhs_pli,
17-
tane::PLIVariantPtr joint_pli) override;
16+
config::ErrorType CalculateFdError(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
17+
model::PLI const* joint_pli) override;
18+
config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
19+
model::PLIWS const* joint_pli) override;
1820

1921
public:
2022
Tane();

src/core/algorithms/fd/tane/tane_common.cpp

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#include <iomanip>
55
#include <list>
66
#include <memory>
7+
#include <variant>
78

89
#include "core/algorithms/fd/pli_based_fd_algorithm.h"
910
#include "core/algorithms/fd/tane/model/lattice_level.h"
@@ -126,16 +127,18 @@ void TaneCommon::ComputeDependencies(model::LatticeLevel* level) {
126127
if (!a_candidates[a_index]) {
127128
continue;
128129
}
129-
std::variant<model::PLI const*, model::PLIWS const*> x_pli, a_pli;
130+
131+
config::ErrorType error;
130132
if (use_pliws_) {
131-
x_pli = x_vertex->GetPositionListIndexWithSingletons();
132-
a_pli = relation_->GetColumnData(a_index).GetPLWSIndex();
133+
model::PLIWS const* x_pli = x_vertex->GetPositionListIndexWithSingletons();
134+
model::PLIWS const* a_pli = relation_->GetColumnData(a_index).GetPLWSIndex();
135+
error = CalculateFdError(x_pli, a_pli, std::get<model::PLIWS const*>(xa_pli));
133136
} else {
134-
x_pli = x_vertex->GetPositionListIndex();
135-
a_pli = relation_->GetColumnData(a_index).GetPositionListIndex();
137+
model::PLI const* x_pli = x_vertex->GetPositionListIndex();
138+
model::PLI const* a_pli = relation_->GetColumnData(a_index).GetPositionListIndex();
139+
error = CalculateFdError(x_pli, a_pli, std::get<model::PLI const*>(xa_pli));
136140
}
137141
// Check X -> A
138-
config::ErrorType error = CalculateFdError(x_pli, a_pli, xa_pli);
139142
if (error <= max_fd_error_) {
140143
Column const* rhs = schema->GetColumns()[a_index].get();
141144

src/core/algorithms/fd/tane/tane_common.h

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
#include "core/model/table/position_list_index.h"
99

1010
namespace algos::tane {
11-
using PLIVariantPtr = std::variant<model::PLI const*, model::PLIWS const*>;
1211

1312
class TaneCommon : public PliBasedFDAlgorithm {
1413
protected:
@@ -23,9 +22,12 @@ class TaneCommon : public PliBasedFDAlgorithm {
2322
void ComputeDependencies(model::LatticeLevel* level);
2423
unsigned long long ExecuteInternal() final;
2524
virtual config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) = 0;
26-
virtual config::ErrorType CalculateFdError(PLIVariantPtr lhs_pli,
27-
[[maybe_unused]] PLIVariantPtr rhs_pli,
28-
PLIVariantPtr joint_pli) = 0;
25+
virtual config::ErrorType CalculateFdError(model::PLI const* lhs_pli,
26+
[[maybe_unused]] model::PLI const* rhs_pli,
27+
model::PLI const* joint_pli) = 0;
28+
virtual config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli,
29+
[[maybe_unused]] model::PLIWS const* rhs_pli,
30+
model::PLIWS const* joint_pli) = 0;
2931
static double CalculateUccError(model::PositionListIndex const* pli,
3032
ColumnLayoutRelationData const* relation_data);
3133
void RegisterAndCountFd(Vertical lhs, Column const* rhs);

src/core/config/use_pliws/option.cpp

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
#include "core/config/use_pliws/option.h"
22

3-
#include <limits>
4-
53
#include "core/config/names_and_descriptions.h"
64

75
namespace config {

src/tests/unit/test_afd_metric_calculator.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,8 @@ struct AFDMetricCalculatorParams {
2121
{kLhsIndices, std::move(lhs_indices)},
2222
{kRhsIndices, std::move(rhs_indices)},
2323
{kEqualNulls, true},
24-
{kMetric, metric},
25-
{"use_pliws", use_pliws}}),
24+
{kUsePliws, use_pliws},
25+
{kMetric, metric}}),
2626
expected(expected) {}
2727
};
2828

src/tests/unit/test_pfdtane.cpp

Lines changed: 34 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,12 @@ struct PFDTaneMiningParams {
2121
unsigned int result_hash;
2222

2323
PFDTaneMiningParams(unsigned int result_hash, config::ErrorType error,
24-
algos::PfdErrorMeasure error_measure, CSVConfig const& csv_config)
24+
algos::PfdErrorMeasure error_measure, CSVConfig const& csv_config,
25+
bool use_pliws = false)
2526
: params({{onam::kCsvConfig, csv_config},
2627
{onam::kError, error},
27-
{onam::kPfdErrorMeasure, error_measure}}),
28+
{onam::kPfdErrorMeasure, error_measure},
29+
{onam::kUsePliws, use_pliws}}),
2830
result_hash(result_hash) {}
2931
};
3032

@@ -63,19 +65,37 @@ TEST_P(TestPFDTaneValidation, ErrorCalculationTest) {
6365
INSTANTIATE_TEST_SUITE_P(
6466
PFDTaneTestMiningSuite, TestPFDTaneMining,
6567
::testing::Values(
66-
PFDTaneMiningParams(44381, 0.3, +algos::PfdErrorMeasure::per_value, kTestFD),
67-
PFDTaneMiningParams(19266, 0.1, +algos::PfdErrorMeasure::per_value, kIris),
68-
PFDTaneMiningParams(10695, 0.01, +algos::PfdErrorMeasure::per_value, kIris),
69-
PFDTaneMiningParams(44088, 0.1, +algos::PfdErrorMeasure::per_value, kNeighbors10k),
70-
PFDTaneMiningParams(41837, 0.01, +algos::PfdErrorMeasure::per_value, kNeighbors10k)
71-
));
68+
PFDTaneMiningParams(44381, 0.3, +algos::PfdErrorMeasure::per_value, kTestFD),
69+
PFDTaneMiningParams(44381, 0.3, +algos::PfdErrorMeasure::per_value, kTestFD, true),
70+
PFDTaneMiningParams(19266, 0.1, +algos::PfdErrorMeasure::per_value, kIris),
71+
PFDTaneMiningParams(19266, 0.1, +algos::PfdErrorMeasure::per_value, kIris, true),
72+
PFDTaneMiningParams(10695, 0.01, +algos::PfdErrorMeasure::per_value, kIris),
73+
PFDTaneMiningParams(10695, 0.01, +algos::PfdErrorMeasure::per_value, kIris, true),
74+
PFDTaneMiningParams(44088, 0.1, +algos::PfdErrorMeasure::per_value, kNeighbors10k),
75+
PFDTaneMiningParams(44088, 0.1, +algos::PfdErrorMeasure::per_value, kNeighbors10k,
76+
true),
77+
PFDTaneMiningParams(41837, 0.01, +algos::PfdErrorMeasure::per_value, kNeighbors10k),
78+
PFDTaneMiningParams(41837, 0.01, +algos::PfdErrorMeasure::per_value, kNeighbors10k,
79+
true)));
7280

7381
INSTANTIATE_TEST_SUITE_P(
7482
PFDTaneTestValidationSuite, TestPFDTaneValidation,
75-
::testing::Values(
76-
PFDTaneValidationParams({{2, 3, 0.0625}, {4, 5, 0.333333}, {3, 2, 0.291666}, {0, 1, 0.75},
77-
{1, 0, 0.0}, {4, 3, 0.099999}, {1, 5, 0.416666}, {5, 1, 0.0}}, +algos::PfdErrorMeasure::per_value, kTestFD),
78-
PFDTaneValidationParams({{2, 3, 0.083333}, {4, 5, 0.333333}, {3, 2, 0.5}, {0, 1, 0.75},
79-
{1, 0, 0.0}, {4, 3, 0.083333}, {1, 5, 0.416666}, {5, 1, 0.0}}, +algos::PfdErrorMeasure::per_tuple, kTestFD)
80-
));
83+
::testing::Values(PFDTaneValidationParams({{2, 3, 0.0625},
84+
{4, 5, 0.333333},
85+
{3, 2, 0.291666},
86+
{0, 1, 0.75},
87+
{1, 0, 0.0},
88+
{4, 3, 0.099999},
89+
{1, 5, 0.416666},
90+
{5, 1, 0.0}},
91+
+algos::PfdErrorMeasure::per_value, kTestFD),
92+
PFDTaneValidationParams({{2, 3, 0.083333},
93+
{4, 5, 0.333333},
94+
{3, 2, 0.5},
95+
{0, 1, 0.75},
96+
{1, 0, 0.0},
97+
{4, 3, 0.083333},
98+
{1, 5, 0.416666},
99+
{5, 1, 0.0}},
100+
+algos::PfdErrorMeasure::per_tuple, kTestFD)));
81101
} // namespace tests

0 commit comments

Comments
 (0)