Skip to content

Commit 8236882

Browse files
committed
trig: cleanup remaining workloads
1 parent 955888e commit 8236882

1 file changed

Lines changed: 182 additions & 95 deletions

File tree

tests/recipes/trig.rs

Lines changed: 182 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ fn workload_symmetry_periodicity() -> Workload {
3131
let op = Workload::new(["sin", "cos", "tan"]);
3232
let var = Workload::new(["a"]);
3333

34-
let t_shift = Workload::new(["x", "(~ x)", "(+ PI x)", "(- PI x)", "(+ x x)"]).plug("x", &var);
34+
// shift argument to function
35+
// (and optionally) negate it
36+
let t_shift = Workload::new(["t", "(~ t)", "(+ PI t)", "(- PI t)", "(+ t t)"]).plug("t", &var);
3537
let t_simpl = Workload::new(["(op t)"]).plug("op", &op).plug("t", &t_shift);
3638
let t_neg = Workload::new(["(~ t)"]).plug("t", &t_simpl);
3739

@@ -44,47 +46,148 @@ fn workload_sum_of_squares() -> Workload {
4446
let op = Workload::new(["sin", "cos", "tan"]);
4547
let var = Workload::new(["a", "b"]);
4648

47-
let double_angle = Filter::Or(vec![
48-
Filter::Contains("(sin (+ ?a ?a))".parse().unwrap()),
49-
Filter::Contains("(cos (+ ?a ?a))".parse().unwrap()),
50-
Filter::Contains("(tan (+ ?a ?a))".parse().unwrap()),
49+
let is_double_angle = Filter::Or(vec![
50+
Filter::Contains("(sin (+ ?x ?x))".parse().unwrap()),
51+
Filter::Contains("(cos (+ ?x ?x))".parse().unwrap()),
52+
Filter::Contains("(tan (+ ?x ?x))".parse().unwrap()),
5153
]);
5254

53-
let t_simpl = Workload::new(["(op t)"]).plug("op", &op).plug("t", &var);
54-
let t_sqr = Workload::new(["(* t t)"]).plug("t", &t_simpl);
55+
// sum (or difference) of squares (of trig functions)
56+
let t_sqr = Workload::new(["(sqr t)"])
57+
.plug("t", &Workload::new(["(op v)"]))
58+
.plug("op", &op)
59+
.plug("v", &var);
5560
let t_sos = Workload::new(["(+ t t)", "(- t t)"]).plug("t", &t_sqr);
5661

5762
workload_symmetry_periodicity()
58-
.filter(Filter::Invert(Box::new(double_angle)))
63+
.filter(Filter::Invert(Box::new(is_double_angle)))
5964
.append(t_sos)
6065
}
6166

6267
/// Terms to prove co-angle identities, e.g., (cos x) => (sin (- (/ PI 2) x))
6368
fn workload_coangle() -> Workload {
64-
let op = Workload::new(["sin", "cos", "tan"]);
69+
let op = Workload::new(["sin", "cos"]);
6570
let var = Workload::new(["a", "b"]);
66-
let cnst = Workload::new([
67-
"0", "(/ PI 6)", "(/ PI 4)", "(/ PI 3)", "(/ PI 2)", "PI", "(* PI 2)",
68-
]);
71+
let cnst = Workload::new(["-2", "-1", "0", "1", "2"]);
6972

70-
let t_shift = Workload::new(["x", "(~ x)", "(+ PI x)", "(- PI x)", "(+ x x)"]).plug("x", &var);
73+
let t_shift = Workload::new(["t", "(- (/ PI 2) t)", "(+ (/ PI 2) t)", "(* 2 t)"]).plug("t", &var);
7174
let t_simpl = Workload::new(["(op t)"]).plug("op", &op).plug("t", &t_shift);
7275

7376
t_simpl.append(cnst)
7477
}
7578

7679
/// Terms to prove power reduction identities, e.g.,
7780
/// (* (cos x) (cos x)) => (1 - (sin x) * (sin x)).
78-
// fn workload_power_reduction() -> Workload {
79-
// let op = Workload::new(["sin", "cos", "tan"]);
80-
// let var = Workload::new(["a", "b"]);
81-
// let cnst = Workload::new([
82-
// "0", "(/ PI 6)", "(/ PI 4)", "(/ PI 3)", "(/ PI 2)", "PI", "(* PI 2)",
83-
// ]);
84-
85-
// let t_xform = Workload::new(["a", "(- a)", "(+ x)", "(+ (/ PI 2) x)", "(* 2 x)"]).plug("x", &var);
86-
87-
// }
81+
fn workload_power_reduction() -> Workload {
82+
let op = Workload::new(["sin", "cos"]);
83+
let var = Workload::new(["a"]);
84+
let cnst = Workload::new(["-2", "-1", "0", "1", "2"]);
85+
86+
// squared trig functions with variable arguments
87+
let t_trig = Workload::new(["(op t)"]).plug("op", &op).plug("t", &var);
88+
let t_sqr = Workload::new(["(* t t)"]).plug("t", &t_trig);
89+
90+
// trig functions (with possibly shifted arguments, and shifted output)
91+
let t_xform = Workload::new(["t", "(- (/ PI 2) t)", "(+ (/ PI 2) t)", "(* 2 t)"]).plug("t", &var);
92+
let t_trig_xform = Workload::new(["(op t)"]).plug("op", &op).plug("t", &t_xform);
93+
let t_shift = Workload::new(["t", "(- 1 t)", "(+ 1 t)"]).plug("t", &t_trig_xform);
94+
95+
// merge and scale
96+
let t_prescale = t_shift.append(t_sqr);
97+
let t_scale = Workload::new(["t", "(/ t 2)"]).plug("t", &t_prescale);
98+
99+
t_scale.append(cnst)
100+
}
101+
102+
/// Terms to prove product-to-sum identities, e.g.,
103+
/// (* (cos x) (cos y)) => (/ (+ (cos (- x y)) (cos (+ x y))) 2).
104+
fn workload_product_to_sum() -> Workload {
105+
let op = Workload::new(["sin", "cos"]);
106+
let cnst = Workload::new(["-2", "-1", "0", "1", "2"]);
107+
108+
// filter for square terms
109+
let is_square = Filter::Or(vec![
110+
Filter::Contains("(* (cos ?x) (cos ?x))".parse().unwrap()),
111+
Filter::Contains("(* (sin ?x) (sin ?x))".parse().unwrap()),
112+
]);
113+
114+
// simple arguments to trig functions
115+
let t_simpl = Workload::new(["a", "b", "(+ a b)", "(- a b)"]);
116+
117+
// trig functions with variable arguments
118+
let t_2var = Workload::new(["(op t)"]).plug("op", &op).plug("t", &t_simpl);
119+
120+
// product of trig functions (no squares)
121+
let t_prod = Workload::new(["(* t1 t2)"])
122+
.plug("t1", &t_2var)
123+
.plug("t2", &t_2var)
124+
.filter(Filter::Invert(Box::new(is_square)));
125+
126+
// sum of trig functions
127+
let t_sum = Workload::new(["(+ t1 t2)", "(- t1 t2)"])
128+
.plug("t1", &t_2var)
129+
.plug("t2", &t_2var);
130+
131+
// merge and scale
132+
let t_prescale = t_sum.append(t_prod);
133+
let t_scale = Workload::new(["t", "(/ t 2)"]).plug("t", &t_prescale);
134+
135+
t_scale.append(cnst)
136+
}
137+
138+
/// Terms to prove sum-to-product identities, e.g.,
139+
/// (+ (cos x) (cos y)) => (* 2 (cos (/ (+ x y) 2)) (cos (/ (- x y) 2))).
140+
fn workload_sum_to_product() -> Workload {
141+
let op = Workload::new(["sin", "cos"]);
142+
let cnst = Workload::new(["-2", "-1", "0", "1", "2"]);
143+
144+
// filter for non-trivial trig terms
145+
let is_nontrivial = Filter::Or(vec![
146+
Filter::Contains("(cos (?op ?x ?y))".parse().unwrap()),
147+
Filter::Contains("(sin (?op ?x ?y))".parse().unwrap()),
148+
]);
149+
150+
// filter for difference of angles
151+
let is_diff = Filter::Or(vec![
152+
Filter::Contains("(cos (- ?x ?y))".parse().unwrap()),
153+
Filter::Contains("(sin (- ?x ?y))".parse().unwrap()),
154+
]);
155+
156+
// filter for doubling
157+
let is_double = Filter::Contains("(+ ?x ?x)".parse().unwrap());
158+
159+
// filter for square terms
160+
let is_square = Filter::Or(vec![
161+
Filter::Contains("(* (cos ?x) (cos ?x))".parse().unwrap()),
162+
Filter::Contains("(* (sin ?x) (sin ?x))".parse().unwrap()),
163+
]);
164+
165+
// simple arguments to trig functions
166+
let t_simpl = Workload::new(["a", "b", "(+ a b)", "(- a b)"]);
167+
168+
// trig functions with variable arguments
169+
let t_2var = Workload::new(["(op t)"]).plug("op", &op).plug("t", &t_simpl);
170+
171+
// product of trig functions (no squares)
172+
let t_prod = Workload::new(["(* t1 t2)"])
173+
.plug("t", &t_2var)
174+
.plug("t1", &t_2var)
175+
.filter(Filter::Invert(Box::new(is_square)));
176+
177+
// sum-of-product terms
178+
let t_sop = Workload::new(["(+ t1 t2)", "(- t1 t2)"])
179+
.plug("t1", &t_prod)
180+
.plug("t2", &t_2var)
181+
.filter(Filter::Invert(Box::new(is_double)))
182+
.filter(Filter::Invert(Box::new(is_nontrivial)));
183+
184+
// remove difference of angles
185+
let t_2var_no_sub = t_2var.filter(Filter::Invert(Box::new(is_diff)));
186+
187+
t_2var_no_sub
188+
.append(t_sop)
189+
.append(cnst)
190+
}
88191

89192

90193
pub fn trig_rules() -> Ruleset<Trig> {
@@ -104,125 +207,109 @@ pub fn trig_rules() -> Ruleset<Trig> {
104207

105208
/////////////////////////////////////////////////////////////////
106209
// workload 1: constants
107-
println!("Starting 1");
210+
println!("starting 1: constants");
108211
let wkld_consts = workload_consts();
109212
let rules = run_fast_forwarding(wkld_consts.clone(), all.clone(), limits, limits);
110213
all.extend(rules.clone());
111214
new.extend(rules.clone());
112215

113216
/////////////////////////////////////////////////////////////////
114217
// workload 2: even/odd symmetry and periodicity
115-
println!("Starting 2");
218+
println!("starting 2: symmetry/periodicity");
116219
let wkld_sym_per = workload_symmetry_periodicity();
117220
let rules = run_fast_forwarding(wkld_sym_per.clone(), all.clone(), limits, limits);
118221
all.extend(rules.clone());
119222
new.extend(rules.clone());
120223

121224
/////////////////////////////////////////////////////////////////
122225
// workload 3: sum of squares
123-
println!("Starting 3");
226+
println!("starting 3: sum of squares");
124227
let wkld_sos = workload_sum_of_squares();
125228
let rules = run_fast_forwarding(wkld_sos, all.clone(), limits, limits);
126229
all.extend(rules.clone());
127230
new.extend(rules.clone());
128231

129232
/////////////////////////////////////////////////////////////////
130233
// workload 4: coangles
131-
println!("Starting 4");
234+
println!("starting 4: coangles");
132235
let wkld_coangle = workload_coangle();
133236
let rules = run_fast_forwarding(wkld_coangle, all.clone(), limits, limits);
134237
all.extend(rules.clone());
135238
new.extend(rules.clone());
136239

137240
/////////////////////////////////////////////////////////////////
138241
// workload 5: power reduction
139-
println!("Starting 5");
242+
println!("starting 5: power reduction");
243+
let wkld_power = workload_power_reduction();
244+
let rules = run_fast_forwarding(wkld_power, all.clone(), limits, limits);
245+
all.extend(rules.clone());
246+
new.extend(rules.clone());
140247

141248
/////////////////////////////////////////////////////////////////
142249
// workload 6: product-to-sum reduction
143-
println!("Starting 6");
250+
println!("starting 6: product-to-sum");
251+
let wkld_prod_sum = workload_product_to_sum();
252+
let rules = run_fast_forwarding(wkld_prod_sum, all.clone(), limits, limits);
253+
all.extend(rules.clone());
254+
new.extend(rules.clone());
144255

145256
/////////////////////////////////////////////////////////////////
146257
// workload 7: sum-to-product reduction
147-
println!("Starting 7");
258+
println!("starting 7: sum-to-product");
259+
let wkld_sum_prod = workload_sum_to_product();
260+
let rules = run_fast_forwarding(wkld_sum_prod, all.clone(), limits, limits);
261+
all.extend(rules.clone());
262+
new.extend(rules.clone());
148263

149264

150-
let non_square_filter = Filter::Invert(Box::new(Filter::Or(vec![
151-
Filter::Contains("(* (cos ?x) (cos ?x))".parse().unwrap()),
152-
Filter::Contains("(* (sin ?x) (sin ?x))".parse().unwrap()),
153-
])));
265+
new
266+
}
154267

155-
let two_x_filter = Filter::Invert(Box::new(Filter::Contains("(+ ?x ?x)".parse().unwrap())));
156268

157-
let trivial_trig_filter = Filter::Invert(Box::new(Filter::Or(vec![
158-
Filter::Contains("(cos (?op ?a ?b))".parse().unwrap()),
159-
Filter::Contains("(sin (?op ?a ?b))".parse().unwrap()),
160-
])));
269+
#[test]
270+
fn sandbox() {
161271

162-
let trig_no_sub_filter = Filter::Invert(Box::new(Filter::Or(vec![
163-
Filter::Contains("(cos (- ?a ?b))".parse().unwrap()),
164-
Filter::Contains("(sin (- ?a ?b))".parse().unwrap()),
272+
let no_trig_2x = Filter::Invert(Box::new(Filter::Or(vec![
273+
Filter::Contains("(sin (+ ?a ?a))".parse().unwrap()),
274+
Filter::Contains("(cos (+ ?a ?a))".parse().unwrap()),
275+
Filter::Contains("(tan (+ ?a ?a))".parse().unwrap()),
165276
])));
277+
let valid_trig = Filter::Invert(Box::new(Filter::Contains(
278+
"(tan (/ PI 2))".parse().unwrap(),
279+
)));
166280

167-
let t_ops = Workload::new(["sin", "cos"]);
281+
let t_ops = Workload::new(["sin", "cos", "tan"]);
282+
let consts = Workload::new([
283+
"0", "(/ PI 6)", "(/ PI 4)", "(/ PI 3)", "(/ PI 2)", "PI", "(* PI 2)",
284+
]);
168285
let app = Workload::new(["(op v)"]);
169-
let shift = Workload::new(["x", "(- 1 x)", "(+ 1 x)"]);
170-
let scale_down = Workload::new(["x", "(/ x 2)"]);
171-
let consts = Workload::new(["-2", "-1", "0", "1", "2"]);
286+
let trig_constants = app
287+
.clone()
288+
.plug("op", &t_ops)
289+
.plug("v", &consts)
290+
.filter(valid_trig);
172291

173-
let simple = app.clone().plug("op", &t_ops).plug(
292+
let simple_terms = app.clone().plug("op", &t_ops).plug(
174293
"v",
175-
&Workload::new(["a", "(- (/ PI 2) a)", "(+ (/ PI 2) a)", "(* 2 a)"]),
294+
&Workload::new(["a", "(~ a)", "(+ PI a)", "(- PI a)", "(+ a a)"]),
176295
);
177296

178-
let trivial_squares = Workload::new(["(sqr x)"])
297+
let neg_terms = Workload::new(["(~ x)"]).plug("x", &simple_terms);
298+
299+
let squares = Workload::new(["(sqr x)"])
179300
.plug("x", &app)
180301
.plug("op", &t_ops)
181-
.plug("v", &Workload::new(["a"]));
302+
.plug("v", &Workload::new(["a", "b"]));
182303

183-
let two_var = app
184-
.clone()
185-
.plug("op", &t_ops)
186-
.plug("v", &Workload::new(["a", "b", "(+ a b)", "(- a b)"]));
187-
let sum_two_vars = Workload::new(["(+ x y)", "(- x y)"])
188-
.plug("x", &two_var)
189-
.plug("y", &two_var);
190-
let prod_two_vars = Workload::new(["(* x y)"])
191-
.plug("x", &two_var)
192-
.plug("y", &two_var)
193-
.filter(non_square_filter);
194-
195-
let sum_of_prod = Workload::new(["(+ x y)", "(- x y)"])
196-
.plug("x", &prod_two_vars)
197-
.plug("y", &prod_two_vars)
198-
.filter(two_x_filter)
199-
.filter(trivial_trig_filter);
200-
201-
let shifted_simple = shift.clone().plug("x", &simple);
202-
let sum_and_prod = Workload::Append(vec![sum_two_vars.clone(), prod_two_vars.clone()]);
203-
let shifted_simple_sqrs = Workload::Append(vec![shifted_simple, trivial_squares]);
204-
let scaled_shifted_sqrs = scale_down.clone().plug("x", &shifted_simple_sqrs);
205-
206-
let scaled_sum_prod = scale_down.clone().plug("x", &sum_and_prod);
207-
208-
let two_var_no_sub = two_var.clone().filter(trig_no_sub_filter);
209-
210-
// Power reduction
211-
let wkld2 = Workload::Append(vec![scaled_shifted_sqrs, consts.clone()]);
212-
let rules2 = run_fast_forwarding(wkld2.clone(), all.clone(), limits, limits);
213-
all.extend(rules2.clone());
214-
new.extend(rules2.clone());
215-
216-
// Product-to-sum
217-
let wkld3 = Workload::Append(vec![scaled_sum_prod, consts.clone()]);
218-
let rules3 = run_fast_forwarding(wkld3.clone(), all.clone(), limits, limits);
219-
all.extend(rules3.clone());
220-
new.extend(rules3.clone());
221-
222-
// Sums
223-
let wkld4 = Workload::Append(vec![two_var_no_sub, sum_of_prod, consts.clone()]);
224-
let rules4 = run_fast_forwarding(wkld4.clone(), all.clone(), limits, limits);
225-
all.extend(rules4.clone());
226-
new.extend(rules4.clone());
227-
new
304+
let add = Workload::new(["(+ e e)", "(- e e)"]);
305+
306+
let sum_of_squares = add.plug("e", &squares);
307+
308+
let wkld1 = trig_constants;
309+
let wkld2 = Workload::Append(vec![wkld1.clone(), simple_terms, neg_terms]);
310+
let trimmed_wkld2 = wkld2.clone().filter(no_trig_2x);
311+
let wkld3 = Workload::Append(vec![trimmed_wkld2.clone(), sum_of_squares.clone()]);
312+
313+
let wkld = workload_sum_of_squares();
314+
assert_eq!(wkld.force().len(), wkld3.force().len());
228315
}

0 commit comments

Comments
 (0)