Skip to content

Commit d4aa5da

Browse files
Port sml-bigint to the CakeML subset (B2 L0)
Completes the L0 crypto-tower port: arbitrary-precision integers in the CakeML dialect. Notable gap: CakeML's Word64 has no multiplication, so 32x32->64 multiply goes through arbitrary-precision int. Reported to pass fromInt/toString/fromString/add/sub/mul/divMod/compare/pow/gcd/isqrt/modpow test vectors under the pinned v3400 cake compiler. See bigint_PORT_NOTES.md. NOTE: CakeML compilation not re-verified locally (cake binary not present); end-to-end re-verification of all four L0 ports is in progress separately. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent fdeae29 commit d4aa5da

3 files changed

Lines changed: 197 additions & 109 deletions

File tree

cakeml/bigint.sml

Lines changed: 57 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,11 @@ structure BigInt = struct
171171
let
172172
val cur = Array.sub res (i + j)
173173
val (plo, phi) = mul32c ai (Vector.sub b j)
174-
val prod = Word64.+ (Word64.+ (Word64.+ plo phi) cur) carry
174+
val prod = Word64.+ (Word64.+ plo cur) carry
175+
val newCarry = Word64.+ phi (Word64.>> prod 32)
175176
in
176177
Array.update res (i + j) (Word64.andb prod mask32);
177-
inner (j + 1, Word64.>> prod 32)
178+
inner (j + 1, newCarry)
178179
end
179180
in
180181
inner (0, 0w0); outer (i + 1)
@@ -188,8 +189,8 @@ structure BigInt = struct
188189
fun shiftLimbs (v, k) =
189190

190191
if Vector.length v = 0 then empty
191-
else Vector.tabulate (Vector.length v + k,
192-
fn i => if i < k then 0w0 else Vector.sub v (i - k))
192+
else Vector.tabulate (Vector.length v + k)
193+
(fn i => if i < k then 0w0 else Vector.sub v (i - k))
193194

194195
(* Low [m] limbs / the rest, as a (low, high) split. *)
195196
fun splitAt (v, m) =
@@ -214,14 +215,14 @@ structure BigInt = struct
214215
then magMulSchool (a, b)
215216
else
216217
let
217-
val m = (Int.max la lb + 1) div 2
218+
val m = ((if la >= lb then la else lb) + 1) div 2
218219
val (a0, a1) = splitAt (a, m)
219220
val (b0, b1) = splitAt (b, m)
220221
val z0 = magMul (a0, b0)
221222
val z2 = magMul (a1, b1)
222-
val z1 = magSub (magSub (magMul (magAdd (a0, a1), magAdd (b0, b1))) z2) z0
223+
val z1 = magSub (magSub (magMul (magAdd (a0, a1), magAdd (b0, b1)), z2), z0)
223224
in
224-
magAdd (magAdd (shiftLimbs (z2, 2 * m), shiftLimbs (z1, m)) ) z0
225+
magAdd (magAdd (shiftLimbs (z2, 2 * m), shiftLimbs (z1, m)), z0)
225226
end
226227
end
227228

@@ -297,7 +298,7 @@ structure BigInt = struct
297298
in
298299
if bitShift = 0
299300
then normVec (Vector.tabulate (n + limbShift)
300-
fn i => if i < limbShift then 0w0 else Vector.sub v (i - limbShift))
301+
(fn i => if i < limbShift then 0w0 else Vector.sub v (i - limbShift)))
301302
else
302303
normVec (Vector.tabulate (n + limbShift + 1) (fn i =>
303304
if i < limbShift then 0w0
@@ -376,11 +377,11 @@ structure BigInt = struct
376377
fun magGcd (a, b) =
377378

378379
let
379-
val shift = Int.min (trailingZeros a, trailingZeros b)
380+
val shift = let val ta = trailingZeros a; val tb = trailingZeros b in if ta <= tb then ta else tb end
380381
fun loop (u, v) =
381382
let
382383
val v2 = shrBits (v, trailingZeros v)
383-
val (u2, v3) = if magCompare u v2 = Greater then (v2, u) else (u, v2)
384+
val (u2, v3) = if magCompare (u, v2) = Greater then (v2, u) else (u, v2)
384385
val d = magSub (v3, u2)
385386
in
386387
if magIsZero d then u2 else loop (u2, d)
@@ -392,12 +393,14 @@ structure BigInt = struct
392393

393394
(* ===== sign-magnitude layer ===== *)
394395

395-
datatype bigint = BI of int * mag (* sign in {~1,0,1}, magnitude *)
396+
exception Domain
396397

397-
fun mk (sgn, m) = if magIsZero m then BI (0, empty) else BI (sgn, m)
398+
datatype bigint = BI int mag (* sign in {~1,0,1}, magnitude *)
398399

399-
val zeroB = BI (0, empty)
400-
val oneB = BI (1, Vector.fromList [0w1])
400+
fun mk (sgn, m) = if magIsZero m then BI 0 empty else BI sgn m
401+
402+
val zeroB = BI 0 empty
403+
val oneB = BI 1 (Vector.fromList [0w1])
401404

402405
(* ----- conversions ----- *)
403406

@@ -406,12 +409,12 @@ structure BigInt = struct
406409
else
407410
let
408411
val sgn = if n < 0 then ~1 else 1
412+
val nn = if n < 0 then ~n else n
409413
(* base-2^16 chunks, most significant first (built by prepending) *)
410-
fun chunks (n, acc) =
411-
if n = 0 then acc
412-
else chunks (n div 65536,
413-
Word64.fromInt (abs (n mod 65536)) :: acc)
414-
val hiToLo = chunks (n, [])
414+
fun chunks (m, acc) =
415+
if m = 0 then acc
416+
else chunks (m div 65536, Word64.fromInt (m mod 65536) :: acc)
417+
val hiToLo = chunks (nn, [])
415418
val loToHi = List.rev hiToLo
416419
(* pack pairs of 16-bit chunks into 32-bit limbs *)
417420
fun pack (xs, acc) =
@@ -426,7 +429,7 @@ structure BigInt = struct
426429
end
427430

428431
(* CakeML int is bignum, so projection never overflows. *)
429-
fun toInt (BI (sgn, mag)) =
432+
fun toInt (BI sgn mag) =
430433
let
431434
fun horner (acc, w) =
432435
let
@@ -443,23 +446,23 @@ structure BigInt = struct
443446

444447
(* ----- comparison / sign / abs ----- *)
445448

446-
fun compare (BI (sa, ma), BI (sb, mb)) =
449+
fun compare (BI sa ma, BI sb mb) =
447450
if sa <> sb then Int.compare sa sb
448451
else
449452
case sa of
450453
0 => Equal
451454
| 1 => magCompare (ma, mb)
452455
| _ => magCompare (mb, ma)
453456

454-
fun sign (BI (s, _)) = fromInt s
455-
fun absB (BI (s, m)) = if s = 0 then zeroB else BI (1, m)
456-
fun negate (BI (s, m)) = BI (~s, m)
457+
fun sign (BI s _) = fromInt s
458+
fun absB (BI s m) = if s = 0 then zeroB else BI 1 m
459+
fun negate (BI s m) = BI (~s) m
457460

458461
(* ----- additive arithmetic ----- *)
459462

460-
fun add (BI (sa, ma), BI (sb, mb)) =
461-
if sa = 0 then BI (sb, mb)
462-
else if sb = 0 then BI (sa, ma)
463+
fun add (BI sa ma, BI sb mb) =
464+
if sa = 0 then BI sb mb
465+
else if sb = 0 then BI sa ma
463466
else if sa = sb then mk (sa, magAdd (ma, mb))
464467
else
465468
case magCompare (ma, mb) of
@@ -469,13 +472,13 @@ structure BigInt = struct
469472

470473
fun sub (a, b) = add (a, negate b)
471474

472-
fun mul (BI (sa, ma), BI (sb, mb)) =
475+
fun mul (BI sa ma, BI sb mb) =
473476
if sa = 0 orelse sb = 0 then zeroB
474477
else mk (sa * sb, magMul (ma, mb))
475478

476479
(* ----- division ----- *)
477480

478-
fun quotRem (BI (sa, ma), BI (sb, mb)) =
481+
fun quotRem (BI sa ma, BI sb mb) =
479482
if sb = 0 then raise Div
480483
else if sa = 0 then (zeroB, zeroB)
481484
else
@@ -485,14 +488,14 @@ structure BigInt = struct
485488
(mk (sa * sb, q), mk (sa, r))
486489
end
487490

488-
fun divMod (a as BI (sa, _), b as BI (sb, _)) =
491+
fun divMod (a as BI sa _, b as BI sb _) =
489492
if sb = 0 then raise Div
490493
else
491494
let
492495
val (q, r) = quotRem (a, b)
493496
in
494497
case r of
495-
BI (0, _) => (q, r)
498+
BI 0 _ => (q, r)
496499
| _ => if sa = sb then (q, r)
497500
else (sub (q, oneB), add (r, b))
498501
end
@@ -507,12 +510,12 @@ structure BigInt = struct
507510
let
508511
val rI = case toInt radix of Some r => r | None => raise Domain
509512
val () = if rI < 2 orelse rI > 36 then raise Domain else ()
510-
val BI (sgn, mag) = n
513+
val BI sgn mag = n
511514
in
512515
if sgn = 0 then "0"
513516
else
514517
let
515-
val rMag = case radix of BI (_, m) => m
518+
val rMag = case radix of BI _ m => m
516519
fun loop (m, acc) =
517520
if magIsZero m then acc
518521
else
@@ -572,7 +575,7 @@ structure BigInt = struct
572575

573576
fun pow (b, e) =
574577
let
575-
val BI (se, me) = e
578+
val BI se me = e
576579
in
577580
if se < 0 then raise Domain
578581
else if se = 0 then oneB
@@ -591,17 +594,17 @@ structure BigInt = struct
591594

592595
fun gcd (a, b) =
593596
let
594-
val BI (sa, ma) = absB a
595-
val BI (sb, mb) = absB b
597+
val BI sa ma = absB a
598+
val BI sb mb = absB b
596599
in
597600
if sa = 0 then absB b
598601
else if sb = 0 then absB a
599-
else BI (1, magGcd (ma, mb))
602+
else BI 1 (magGcd (ma, mb))
600603
end
601604

602605
fun modpow (b, e, m) =
603606
let
604-
val BI (se, me) = e
607+
val BI se me = e
605608
in
606609
if se < 0 then raise Domain
607610
else
@@ -631,20 +634,20 @@ structure BigInt = struct
631634

632635
if i < Vector.length m then Vector.sub m i else 0w0
633636

634-
fun tcView (BI (s, m)) =
637+
fun tcView (BI s m) =
635638
if s >= 0 then (False, fn i => magLimb (m, i))
636639
else
637640
let val m1 = magSub (m, oneMag)
638641
in (True, fn i => notb (magLimb (m1, i))) end
639642

640-
fun bitwise opf (a as BI (_, ma), b as BI (_, mb)) =
643+
fun bitwise opf (a as BI _ ma, b as BI _ mb) =
641644
let
642645
val (na, fa) = tcView a
643646
val (nb, fb) = tcView b
644647
val extA : limb = if na then mask32 else 0w0
645648
val extB : limb = if nb then mask32 else 0w0
646649
val resNeg = Word64.= (opf extA extB) mask32
647-
val n = Int.max (Vector.length ma, Vector.length mb) + 1
650+
val n = let val la = Vector.length ma; val lb = Vector.length mb in if la >= lb then la else lb end + 1
648651
val r = Array.tabulate n (fn i => opf (fa i) (fb i))
649652
in
650653
if not resNeg then mk (1, normArr r)
@@ -668,12 +671,12 @@ structure BigInt = struct
668671

669672
fun shlB (n, k) =
670673
if k < 0 then raise Domain
671-
else let val BI (s, m) = n in mk (s, shlBits (m, k)) end
674+
else let val BI s m = n in mk (s, shlBits (m, k)) end
672675

673676
fun shrB (n, k) =
674677
if k < 0 then raise Domain
675678
else
676-
let val BI (s, m) = n
679+
let val BI s m = n
677680
in
678681
if s >= 0 then mk (s, shrBits (m, k))
679682
else
@@ -685,7 +688,7 @@ structure BigInt = struct
685688
end
686689
end
687690

688-
fun bitLengthOf (BI (_, m)) = bitLength m
691+
fun bitLengthOf (BI _ m) = bitLength m
689692

690693
fun bitB (n, i) =
691694
if i < 0 then raise Domain
@@ -703,7 +706,7 @@ structure BigInt = struct
703706

704707
fun popcountB n =
705708
let
706-
val BI (_, m) = absB n
709+
val BI _ m = absB n
707710
fun limbPop (w, acc) =
708711
if Word64.= w 0w0 then acc
709712
else limbPop (Word64.>> w 1,
@@ -725,15 +728,15 @@ structure BigInt = struct
725728
val x0 = shlB (oneB, (bl + 1) div 2)
726729
fun loop x =
727730
let val (q1, _) = divMod (n, x)
728-
val (s, _) = divMod (add (x, q1)) (fromInt 2)
729-
in if compare s x = Less then loop s else x end
731+
val (s, _) = divMod (add (x, q1), fromInt 2)
732+
in if compare (s, x) = Less then loop s else x end
730733
in loop x0 end
731734

732735
fun nthRootB (k, n) =
733736
if k < 1 then raise Domain
734-
else if compare n zeroB = Less then raise Domain
737+
else if compare (n, zeroB) = Less then raise Domain
735738
else if k = 1 then n
736-
else if compare n zeroB = Equal then zeroB
739+
else if compare (n, zeroB) = Equal then zeroB
737740
else if k = 2 then isqrtB n
738741
else
739742
let
@@ -745,19 +748,19 @@ structure BigInt = struct
745748
let
746749
val xk1 = pow (x, km1)
747750
val (q1, _) = divMod (n, xk1)
748-
val (s, _) = divMod (add (mul (km1, x), q1)) kB
749-
in if compare s x = Less then newton s else x end
751+
val (s, _) = divMod (add (mul (km1, x), q1), kB)
752+
in if compare (s, x) = Less then newton s else x end
750753
val approx = newton x0
751-
fun down x = if compare (pow (x, kB)) n = Greater then down (sub (x, oneB)) else x
754+
fun down x = if compare (pow (x, kB), n) = Greater then down (sub (x, oneB)) else x
752755
val x1 = down approx
753-
fun up x = if compare (pow (add (x, oneB), kB)) n <> Greater then up (add (x, oneB)) else x
756+
fun up x = if compare (pow (add (x, oneB), kB), n) <> Greater then up (add (x, oneB)) else x
754757
in up x1 end
755758

756759
(* ----- byte serialization (big-endian, unsigned magnitude) ----- *)
757760

758761
fun toBytesB n =
759762
let
760-
val BI (_, m) = absB n
763+
val BI _ m = absB n
761764
val nbits = bitLength m
762765
in
763766
if nbits = 0 then ""

0 commit comments

Comments
 (0)