Skip to content

Commit f6c5724

Browse files
committed
Add PLI variant into the Tane
1 parent 5d1a511 commit f6c5724

9 files changed

Lines changed: 88 additions & 22 deletions

File tree

src/core/algorithms/fd/tane/afd_measures.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ config::ErrorType CalculateZeroAryG1(ColumnData const* rhs, unsigned long long n
88
static_cast<config::ErrorType>(num_tuple_pairs);
99
}
1010

11-
config::ErrorType CalculateG1Error(model::PLIWS const* lhs_pli, model::PLIWS const* joint_pli,
11+
config::ErrorType CalculateG1Error(model::PLI const* lhs_pli, model::PLI const* joint_pli,
1212
unsigned long long num_tuple_pairs) {
1313
return static_cast<config::ErrorType>((lhs_pli->GetNepAsLong() - joint_pli->GetNepAsLong()) /
1414
static_cast<config::ErrorType>(num_tuple_pairs));
@@ -102,7 +102,7 @@ config::ErrorType CalculateMuPlusMeasure(model::PLIWS const* x_pli, model::PLIWS
102102
return mu_plus;
103103
}
104104

105-
config::ErrorType CalculateRhoMeasure(model::PLIWS const* x_pli, model::PLIWS const* xa_pli) {
105+
config::ErrorType CalculateRhoMeasure(model::PLI const* x_pli, model::PLI const* xa_pli) {
106106
auto calculate_dom = [](model::PositionListIndex const* pli) {
107107
auto index = pli->GetIndex();
108108
size_t dom = index.size();

src/core/algorithms/fd/tane/afd_measures.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
namespace algos {
88
config::ErrorType CalculateZeroAryG1(ColumnData const* rhs, unsigned long long num_tuple_pairs);
99

10-
config::ErrorType CalculateG1Error(model::PLIWS const* lhs_pli, model::PLIWS const* joint_pli,
10+
config::ErrorType CalculateG1Error(model::PLI const* lhs_pli, model::PLI const* joint_pli,
1111
unsigned long long num_tuple_pairs);
1212

1313
config::ErrorType PdepSelf(model::PLI const* x_pli);
@@ -20,5 +20,5 @@ config::ErrorType CalculateTauMeasure(model::PLIWS const* x_pli, model::PLIWS co
2020
config::ErrorType CalculateMuPlusMeasure(model::PLIWS const* x_pli, model::PLIWS const* a_pli,
2121
model::PLIWS const* xa_pli);
2222

23-
config::ErrorType CalculateRhoMeasure(model::PLIWS const* x_pli, model::PLIWS const* xa_pli);
23+
config::ErrorType CalculateRhoMeasure(model::PLI const* x_pli, model::PLI const* xa_pli);
2424
} // namespace algos

src/core/algorithms/fd/tane/pfdtane.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ config::ErrorType PFDTane::CalculateFdError(model::PLIWS const* lhs_pli,
3333
return CalculatePFDError(lhs_pli, joint_pli, pfd_error_measure_);
3434
}
3535

36+
config::ErrorType PFDTane::CalculateFdError(model::PLI const* lhs_pli,
37+
[[maybe_unused]] model::PLI const* rhs_pli,
38+
model::PLI const* joint_pli) {
39+
return CalculatePFDError(lhs_pli, joint_pli, pfd_error_measure_);
40+
}
41+
3642
config::ErrorType PFDTane::CalculateZeroAryPFDError(ColumnData const* rhs) {
3743
std::size_t max = 1;
3844
model::PositionListIndex const* x_pli = rhs->GetPositionListIndex();

src/core/algorithms/fd/tane/pfdtane.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ class PFDTane : public tane::TaneCommon {
1515
void RegisterOptions();
1616
void MakeExecuteOptsAvailableFDInternal() final;
1717
config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) override;
18-
config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
18+
config::ErrorType CalculateFdError(model::PLI const* lhs_pli,
19+
[[maybe_unused]] model::PLI const* rhs_pli,
20+
model::PLI const* joint_pli) override;
21+
config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli,
22+
[[maybe_unused]] model::PLIWS const* rhs_pli,
1923
model::PLIWS const* joint_pli) override;
2024

2125
public:

src/core/algorithms/fd/tane/tane.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,27 @@ config::ErrorType Tane::CalculateZeroAryFdError(ColumnData const* rhs) {
2424
return 1;
2525
}
2626

27-
config::ErrorType Tane::CalculateFdError(model::PLIWithSingletons const* lhs_pli,
28-
model::PLIWithSingletons const* rhs_pli,
29-
model::PLIWithSingletons const* joint_pli) {
27+
config::ErrorType Tane::CalculateFdError(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
28+
model::PLI const* joint_pli) {
29+
switch (afd_error_measure_) {
30+
case AfdErrorMeasure::kPdep:
31+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculatePdepMeasure(lhs_pli,
32+
joint_pli);
33+
case AfdErrorMeasure::kTau:
34+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateTau(lhs_pli, rhs_pli,
35+
joint_pli);
36+
case AfdErrorMeasure::kMuPlus:
37+
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculateMuPlus(lhs_pli, rhs_pli,
38+
joint_pli);
39+
case AfdErrorMeasure::kRho:
40+
return 1 - CalculateRhoMeasure(lhs_pli, joint_pli);
41+
default:
42+
return CalculateG1Error(lhs_pli, joint_pli, relation_.get()->GetNumTuplePairs());
43+
}
44+
}
45+
46+
config::ErrorType Tane::CalculateFdError(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
47+
model::PLIWS const* joint_pli) {
3048
switch (afd_error_measure_) {
3149
case AfdErrorMeasure::kPdep:
3250
return 1 - afd_metric_calculator::AFDMetricCalculator::CalculatePdepMeasure(lhs_pli,

src/core/algorithms/fd/tane/tane.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@ class Tane : public tane::TaneCommon {
1313
AfdErrorMeasure afd_error_measure_ = AfdErrorMeasure::kG1;
1414
void MakeExecuteOptsAvailableFDInternal() override final;
1515
config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) override;
16-
config::ErrorType CalculateFdError(model::PLIWithSingletons const* lhs_pli,
17-
model::PLIWithSingletons const* rhs_pli,
18-
model::PLIWithSingletons const* joint_pli) override;
16+
config::ErrorType CalculateFdError(model::PLI const* lhs_pli, model::PLI const* rhs_pli,
17+
model::PLI const* joint_pli) override;
18+
config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli, model::PLIWS const* rhs_pli,
19+
model::PLIWS const* joint_pli) override;
1920

2021
public:
2122
Tane();

src/core/algorithms/fd/tane/tane_common.cpp

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
#include <iomanip>
55
#include <list>
66
#include <memory>
7+
#include <variant>
78

89
#include "core/algorithms/fd/pli_based_fd_algorithm.h"
910
#include "core/algorithms/fd/tane/model/lattice_level.h"
1011
#include "core/algorithms/fd/tane/model/lattice_vertex.h"
1112
#include "core/config/error/option.h"
13+
#include "core/config/option.h"
14+
#include "core/config/use_pliws/option.h"
1215
#include "core/model/table/column_data.h"
1316
#include "core/model/table/column_layout_relation_data.h"
1417
#include "core/model/table/relational_schema.h"
@@ -21,6 +24,7 @@ namespace tane {
2124

2225
TaneCommon::TaneCommon() : PliBasedFDAlgorithm() {
2326
RegisterOption(config::kErrorOpt(&max_ucc_error_));
27+
RegisterOption(config::kUsePliwsOpt(&use_pliws_));
2428
}
2529

2630
double TaneCommon::CalculateUccError(model::PositionListIndex const* pli,
@@ -93,14 +97,27 @@ void TaneCommon::ComputeDependencies(model::LatticeLevel* level) {
9397
Vertical xa = xa_vertex->GetVertical();
9498
// Calculate XA PLI
9599
if (xa_vertex->GetPositionListIndex() == nullptr) {
96-
auto parent_pli_1 = xa_vertex->GetParents()[0]->GetPositionListIndexWithSingletons();
97-
auto parent_pli_2 = xa_vertex->GetParents()[1]->GetPositionListIndexWithSingletons();
98-
xa_vertex->AcquirePLIWithSingletons(parent_pli_1->Intersect(parent_pli_2));
100+
if (use_pliws_) {
101+
auto parent_pli_1 =
102+
xa_vertex->GetParents()[0]->GetPositionListIndexWithSingletons();
103+
auto parent_pli_2 =
104+
xa_vertex->GetParents()[1]->GetPositionListIndexWithSingletons();
105+
xa_vertex->AcquirePLIWithSingletons(parent_pli_1->Intersect(parent_pli_2));
106+
} else {
107+
auto parent_pli_1 = xa_vertex->GetParents()[0]->GetPositionListIndex();
108+
auto parent_pli_2 = xa_vertex->GetParents()[1]->GetPositionListIndex();
109+
xa_vertex->AcquirePositionListIndex(parent_pli_1->Intersect(parent_pli_2));
110+
}
99111
}
100112

101113
dynamic_bitset<> xa_indices = xa.GetColumnIndices();
102114
dynamic_bitset<> a_candidates = xa_vertex->GetRhsCandidates();
103-
auto xa_pli = xa_vertex->GetPositionListIndexWithSingletons();
115+
std::variant<model::PLI const*, model::PLIWS const*> xa_pli;
116+
if (use_pliws_) {
117+
xa_pli = xa_vertex->GetPositionListIndexWithSingletons();
118+
} else {
119+
xa_pli = xa_vertex->GetPositionListIndex();
120+
}
104121
for (auto const& x_vertex : xa_vertex->GetParents()) {
105122
Vertical const& lhs = x_vertex->GetVertical();
106123

@@ -110,10 +127,18 @@ void TaneCommon::ComputeDependencies(model::LatticeLevel* level) {
110127
if (!a_candidates[a_index]) {
111128
continue;
112129
}
113-
auto x_pli = x_vertex->GetPositionListIndexWithSingletons();
114-
auto a_pli = relation_->GetColumnData(a_index).GetPLWSIndex();
130+
131+
config::ErrorType error;
132+
if (use_pliws_) {
133+
model::PLIWS const* x_pli = x_vertex->GetPositionListIndexWithSingletons();
134+
model::PLIWS const* a_pli = relation_->GetColumnData(a_index).GetPLWSIndex();
135+
error = CalculateFdError(x_pli, a_pli, std::get<model::PLIWS const*>(xa_pli));
136+
} else {
137+
model::PLI const* x_pli = x_vertex->GetPositionListIndex();
138+
model::PLI const* a_pli = relation_->GetColumnData(a_index).GetPositionListIndex();
139+
error = CalculateFdError(x_pli, a_pli, std::get<model::PLI const*>(xa_pli));
140+
}
115141
// Check X -> A
116-
config::ErrorType error = CalculateFdError(x_pli, a_pli, xa_pli);
117142
if (error <= max_fd_error_) {
118143
Column const* rhs = schema->GetColumns()[a_index].get();
119144

src/core/algorithms/fd/tane/tane_common.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class TaneCommon : public PliBasedFDAlgorithm {
1313
protected:
1414
config::ErrorType max_fd_error_;
1515
config::ErrorType max_ucc_error_;
16+
bool use_pliws_ = false;
1617

1718
private:
1819
void ResetStateFd() final {}
@@ -21,6 +22,9 @@ class TaneCommon : public PliBasedFDAlgorithm {
2122
void ComputeDependencies(model::LatticeLevel* level);
2223
unsigned long long ExecuteInternal() final;
2324
virtual config::ErrorType CalculateZeroAryFdError(ColumnData const* rhs) = 0;
25+
virtual config::ErrorType CalculateFdError(model::PLI const* lhs_pli,
26+
[[maybe_unused]] model::PLI const* rhs_pli,
27+
model::PLI const* joint_pli) = 0;
2428
virtual config::ErrorType CalculateFdError(model::PLIWS const* lhs_pli,
2529
[[maybe_unused]] model::PLIWS const* rhs_pli,
2630
model::PLIWS const* joint_pli) = 0;

src/tests/unit/test_pfdtane.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ struct PFDTaneMiningParams {
2121
unsigned int result_hash;
2222

2323
PFDTaneMiningParams(unsigned int result_hash, config::ErrorType error,
24-
algos::PfdErrorMeasure error_measure, CSVConfig const& csv_config)
24+
algos::PfdErrorMeasure error_measure, CSVConfig const& csv_config,
25+
bool use_pliws = false)
2526
: params({{onam::kCsvConfig, csv_config},
2627
{onam::kError, error},
27-
{onam::kPfdErrorMeasure, error_measure}}),
28+
{onam::kPfdErrorMeasure, error_measure},
29+
{onam::kUsePliws, use_pliws}}),
2830
result_hash(result_hash) {}
2931
};
3032

@@ -64,11 +66,17 @@ INSTANTIATE_TEST_SUITE_P(
6466
PFDTaneTestMiningSuite, TestPFDTaneMining,
6567
::testing::Values(
6668
PFDTaneMiningParams(44381, 0.3, algos::PfdErrorMeasure::kPerValue, kTestFD),
69+
PFDTaneMiningParams(44381, 0.3, algos::PfdErrorMeasure::kPerValue, kTestFD, true),
6770
PFDTaneMiningParams(19266, 0.1, algos::PfdErrorMeasure::kPerValue, kIris),
71+
PFDTaneMiningParams(19266, 0.1, algos::PfdErrorMeasure::kPerValue, kIris, true),
6872
PFDTaneMiningParams(10695, 0.01, algos::PfdErrorMeasure::kPerValue, kIris),
73+
PFDTaneMiningParams(10695, 0.01, algos::PfdErrorMeasure::kPerValue, kIris, true),
6974
PFDTaneMiningParams(44088, 0.1, algos::PfdErrorMeasure::kPerValue, kNeighbors10k),
70-
PFDTaneMiningParams(41837, 0.01, algos::PfdErrorMeasure::kPerValue,
71-
kNeighbors10k)));
75+
PFDTaneMiningParams(44088, 0.1, algos::PfdErrorMeasure::kPerValue, kNeighbors10k,
76+
true),
77+
PFDTaneMiningParams(41837, 0.01, algos::PfdErrorMeasure::kPerValue, kNeighbors10k),
78+
PFDTaneMiningParams(41837, 0.01, algos::PfdErrorMeasure::kPerValue, kNeighbors10k,
79+
true)));
7280

7381
INSTANTIATE_TEST_SUITE_P(
7482
PFDTaneTestValidationSuite, TestPFDTaneValidation,

0 commit comments

Comments
 (0)