Skip to content

Commit b9d7ed5

Browse files
committed
Add PLI variant into the Tane
1 parent a009254 commit b9d7ed5

8 files changed

Lines changed: 86 additions & 40 deletions

File tree

src/core/algorithms/fd/tane/afd_measures.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ config::ErrorType CalculateZeroAryG1(ColumnData const* rhs, unsigned long long n
88
static_cast<config::ErrorType>(num_tuple_pairs);
99
}
1010

11-
config::ErrorType CalculateG1Error(model::PLIWS const* lhs_pli, model::PLIWS const* joint_pli,
11+
config::ErrorType CalculateG1Error(model::PLI const* lhs_pli, model::PLI const* joint_pli,
1212
unsigned long long num_tuple_pairs) {
1313
return static_cast<config::ErrorType>((lhs_pli->GetNepAsLong() - joint_pli->GetNepAsLong()) /
1414
static_cast<config::ErrorType>(num_tuple_pairs));
@@ -102,7 +102,7 @@ config::ErrorType CalculateMuPlusMeasure(model::PLIWS const* x_pli, model::PLIWS
102102
return mu_plus;
103103
}
104104

105-
config::ErrorType CalculateRhoMeasure(model::PLIWS const* x_pli, model::PLIWS const* xa_pli) {
105+
config::ErrorType CalculateRhoMeasure(model::PLI const* x_pli, model::PLI const* xa_pli) {
106106
auto calculate_dom = [](model::PositionListIndex const* pli) {
107107
auto index = pli->GetIndex();
108108
size_t dom = index.size();

src/core/algorithms/fd/tane/afd_measures.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
namespace algos {
88
config::ErrorType CalculateZeroAryG1(ColumnData const* rhs, unsigned long long num_tuple_pairs);
99

10-
config::ErrorType CalculateG1Error(model::PLIWS const* lhs_pli, model::PLIWS const* joint_pli,
10+
config::ErrorType CalculateG1Error(model::PLI const* lhs_pli, model::PLI const* joint_pli,
1111
unsigned long long num_tuple_pairs);
1212

1313
config::ErrorType PdepSelf(model::PLI const* x_pli);
@@ -20,5 +20,5 @@ config::ErrorType CalculateTauMeasure(model::PLIWS const* x_pli, model::PLIWS co
2020
config::ErrorType CalculateMuPlusMeasure(model::PLIWS const* x_pli, model::PLIWS const* a_pli,
2121
model::PLIWS const* xa_pli);
2222

23-
config::ErrorType CalculateRhoMeasure(model::PLIWS const* x_pli, model::PLIWS const* xa_pli);
23+
config::ErrorType CalculateRhoMeasure(model::PLI const* x_pli, model::PLI const* xa_pli);
2424
} // namespace algos

src/core/algorithms/fd/tane/pfdtane.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,21 @@ config::ErrorType PFDTane::CalculateZeroAryFdError(ColumnData const* rhs) {
2727
return CalculateZeroAryPFDError(rhs);
2828
}
2929

30-
config::ErrorType PFDTane::CalculateFdError(model::PLIWS const* lhs_pli,
31-
[[maybe_unused]] model::PLIWS const* rhs_pli,
32-
model::PLIWS const* joint_pli) {
33-
return CalculatePFDError(lhs_pli, joint_pli, pfd_error_measure_);
30+
config::ErrorType PFDTane::CalculateFdError(tane::PLIVariantPtr lhs_pli,
31+
tane::PLIVariantPtr rhs_pli,
32+
tane::PLIVariantPtr joint_pli) {
33+
auto visitor = [this](auto&& lhs_ptr, auto&& rhs_ptr, auto&& joint_ptr) -> config::ErrorType {
34+
using LhsType = std::decay_t<decltype(lhs_ptr)>;
35+
using RhsType = std::decay_t<decltype(rhs_ptr)>;
36+
using JointType = std::decay_t<decltype(joint_ptr)>;
37+
38+
if constexpr (std::is_same_v<LhsType, RhsType> && std::is_same_v<RhsType, JointType>) {
39+
return CalculatePFDError(lhs_ptr, joint_ptr, pfd_error_measure_);
40+
} else {
41+
std::terminate();
42+
}
43+
};
44+
return std::visit(visitor, lhs_pli, rhs_pli, joint_pli);
3445
}
3546

3647
config::ErrorType PFDTane::CalculateZeroAryPFDError(ColumnData const* rhs) {

src/core/algorithms/fd/tane/pfdtane.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ class PFDTane : public tane::TaneCommon {
1515
void RegisterOptions();
1616
void MakeExecuteOptsAvailableFDInternal() final;
1717
config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) override;
18-
config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
19-
model::PLIWS const* joint_pli) override;
18+
config::ErrorType CalculateFdError(tane::PLIVariantPtr lhs_pli, tane::PLIVariantPtr rhs_pli,
19+
tane::PLIVariantPtr joint_pli) override;
2020

2121
public:
2222
PFDTane();

src/core/algorithms/fd/tane/tane.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,36 @@ config::ErrorType Tane::CalculateZeroAryFdError(ColumnData const* rhs) {
2424
return 1;
2525
}
2626

27-
config::ErrorType Tane::CalculateFdError(model::PLIWithSingletons const* lhs_pli,
28-
model::PLIWithSingletons const* rhs_pli,
29-
model::PLIWithSingletons const* joint_pli) {
30-
switch (afd_error_measure_) {
31-
case +AfdErrorMeasure::pdep:
32-
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculatePdepMeasure(lhs_pli,
33-
joint_pli);
34-
case +AfdErrorMeasure::tau:
35-
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateTau(lhs_pli, rhs_pli,
36-
joint_pli);
37-
case +AfdErrorMeasure::mu_plus:
38-
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateMuPlus(lhs_pli, rhs_pli,
39-
joint_pli);
40-
case +AfdErrorMeasure::rho:
41-
return 1 - CalculateRhoMeasure(lhs_pli, joint_pli);
42-
default:
43-
return CalculateG1Error(lhs_pli, joint_pli, relation_.get()->GetNumTuplePairs());
44-
}
27+
config::ErrorType Tane::CalculateFdError(tane::PLIVariantPtr lhs_pli, tane::PLIVariantPtr rhs_pli,
28+
tane::PLIVariantPtr joint_pli) {
29+
auto visitor = [this](auto&& lhs_pli, auto&& rhs_pli, auto&& joint_pli) -> config::ErrorType {
30+
using LhsType = std::decay_t<decltype(lhs_pli)>;
31+
using RhsType = std::decay_t<decltype(rhs_pli)>;
32+
using JointType = std::decay_t<decltype(joint_pli)>;
33+
34+
if constexpr (std::is_same_v<LhsType, RhsType> and std::is_same_v<RhsType, JointType>) {
35+
switch (afd_error_measure_) {
36+
case +AfdErrorMeasure::pdep:
37+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculatePdepMeasure(
38+
lhs_pli, joint_pli);
39+
case +AfdErrorMeasure::tau:
40+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateTau(
41+
lhs_pli, rhs_pli, joint_pli);
42+
case +AfdErrorMeasure::mu_plus:
43+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateMuPlus(
44+
lhs_pli, rhs_pli, joint_pli);
45+
case +AfdErrorMeasure::rho:
46+
return 1 - CalculateRhoMeasure(lhs_pli, joint_pli);
47+
default:
48+
return CalculateG1Error(lhs_pli, joint_pli,
49+
relation_.get()->GetNumTuplePairs());
50+
}
51+
} else {
52+
std::terminate();
53+
}
54+
};
55+
56+
return std::visit(visitor, lhs_pli, rhs_pli, joint_pli);
4557
}
4658

4759
} // namespace algos

src/core/algorithms/fd/tane/tane.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ class Tane : public tane::TaneCommon {
1313
AfdErrorMeasure afd_error_measure_ = +AfdErrorMeasure::g1;
1414
void MakeExecuteOptsAvailableFDInternal() override final;
1515
config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) override;
16-
config::ErrorType CalculateFdError(model::PLIWithSingletons const* lhs_pli,
17-
model::PLIWithSingletons const* rhs_pli,
18-
model::PLIWithSingletons const* joint_pli) override;
16+
config::ErrorType CalculateFdError(tane::PLIVariantPtr lhs_pli, tane::PLIVariantPtr rhs_pli,
17+
tane::PLIVariantPtr joint_pli) override;
1918

2019
public:
2120
Tane();

src/core/algorithms/fd/tane/tane_common.cpp

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include "core/algorithms/fd/tane/model/lattice_level.h"
1010
#include "core/algorithms/fd/tane/model/lattice_vertex.h"
1111
#include "core/config/error/option.h"
12+
#include "core/config/option.h"
13+
#include "core/config/use_pliws/option.h"
1214
#include "core/model/table/column_data.h"
1315
#include "core/model/table/column_layout_relation_data.h"
1416
#include "core/model/table/relational_schema.h"
@@ -21,6 +23,7 @@ namespace tane {
2123

2224
TaneCommon::TaneCommon() : PliBasedFDAlgorithm() {
2325
RegisterOption(config::kErrorOpt(&max_ucc_error_));
26+
RegisterOption(config::kUsePliwsOpt(&use_pliws_));
2427
}
2528

2629
double TaneCommon::CalculateUccError(model::PositionListIndex const* pli,
@@ -93,14 +96,27 @@ void TaneCommon::ComputeDependencies(model::LatticeLevel* level) {
9396
Vertical xa = xa_vertex->GetVertical();
9497
// Calculate XA PLI
9598
if (xa_vertex->GetPositionListIndex() == nullptr) {
96-
auto parent_pli_1 = xa_vertex->GetParents()[0]->GetPositionListIndexWithSingletons();
97-
auto parent_pli_2 = xa_vertex->GetParents()[1]->GetPositionListIndexWithSingletons();
98-
xa_vertex->AcquirePLIWithSingletons(parent_pli_1->Intersect(parent_pli_2));
99+
if (use_pliws_) {
100+
auto parent_pli_1 =
101+
xa_vertex->GetParents()[0]->GetPositionListIndexWithSingletons();
102+
auto parent_pli_2 =
103+
xa_vertex->GetParents()[1]->GetPositionListIndexWithSingletons();
104+
xa_vertex->AcquirePLIWithSingletons(parent_pli_1->Intersect(parent_pli_2));
105+
} else {
106+
auto parent_pli_1 = xa_vertex->GetParents()[0]->GetPositionListIndex();
107+
auto parent_pli_2 = xa_vertex->GetParents()[1]->GetPositionListIndex();
108+
xa_vertex->AcquirePositionListIndex(parent_pli_1->Intersect(parent_pli_2));
109+
}
99110
}
100111

101112
dynamic_bitset<> xa_indices = xa.GetColumnIndices();
102113
dynamic_bitset<> a_candidates = xa_vertex->GetRhsCandidates();
103-
auto xa_pli = xa_vertex->GetPositionListIndexWithSingletons();
114+
std::variant<model::PLI const*, model::PLIWS const*> xa_pli;
115+
if (use_pliws_) {
116+
xa_pli = xa_vertex->GetPositionListIndexWithSingletons();
117+
} else {
118+
xa_pli = xa_vertex->GetPositionListIndex();
119+
}
104120
for (auto const& x_vertex : xa_vertex->GetParents()) {
105121
Vertical const& lhs = x_vertex->GetVertical();
106122

@@ -110,8 +126,14 @@ void TaneCommon::ComputeDependencies(model::LatticeLevel* level) {
110126
if (!a_candidates[a_index]) {
111127
continue;
112128
}
113-
auto x_pli = x_vertex->GetPositionListIndexWithSingletons();
114-
auto a_pli = relation_->GetColumnData(a_index).GetPLWSIndex();
129+
std::variant<model::PLI const*, model::PLIWS const*> x_pli, a_pli;
130+
if (use_pliws_) {
131+
x_pli = x_vertex->GetPositionListIndexWithSingletons();
132+
a_pli = relation_->GetColumnData(a_index).GetPLWSIndex();
133+
} else {
134+
x_pli = x_vertex->GetPositionListIndex();
135+
a_pli = relation_->GetColumnData(a_index).GetPositionListIndex();
136+
}
115137
// Check X -> A
116138
config::ErrorType error = CalculateFdError(x_pli, a_pli, xa_pli);
117139
if (error <= max_fd_error_) {

src/core/algorithms/fd/tane/tane_common.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
#include "core/model/table/position_list_index.h"
99

1010
namespace algos::tane {
11+
using PLIVariantPtr = std::variant<model::PLI const*, model::PLIWS const*>;
1112

1213
class TaneCommon : public PliBasedFDAlgorithm {
1314
protected:
1415
config::ErrorType max_fd_error_;
1516
config::ErrorType max_ucc_error_;
17+
bool use_pliws_ = false;
1618

1719
private:
1820
void ResetStateFd() final {}
@@ -21,9 +23,9 @@ class TaneCommon : public PliBasedFDAlgorithm {
2123
void ComputeDependencies(model::LatticeLevel* level);
2224
unsigned long long ExecuteInternal() final;
2325
virtual config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) = 0;
24-
virtual config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli,
25-
[[maybe_unused]] model::PLIWS const* rhs_pli,
26-
model::PLIWS const* joint_pli) = 0;
26+
virtual config::ErrorType CalculateFdError(PLIVariantPtr lhs_pli,
27+
[[maybe_unused]] PLIVariantPtr rhs_pli,
28+
PLIVariantPtr joint_pli) = 0;
2729
static double CalculateUccError(model::PositionListIndex const* pli,
2830
ColumnLayoutRelationData const* relation_data);
2931
void RegisterAndCountFd(Vertical lhs, Column const* rhs);

0 commit comments

Comments
 (0)