{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
module Math.NumberTheory.Roots.Squares.Internal
( karatsubaSqrt
, isqrtA
) where
import Data.Bits (finiteBitSize, unsafeShiftL, unsafeShiftR, (.&.), (.|.))
import GHC.Exts (Int(..), Int#, isTrue#, int2Double#, sqrtDouble#, double2Int#, (<#))
#ifdef MIN_VERSION_integer_gmp
import GHC.Exts (uncheckedIShiftRA#, (*#), (-#))
import GHC.Integer.GMP.Internals (Integer(..), shiftLInteger, shiftRInteger, sizeofBigNat#)
import GHC.Integer.Logarithms (integerLog2#)
#define IS S#
#define IP Jp#
#define bigNatSize sizeofBigNat
#else
import GHC.Exts (uncheckedShiftRL#, word2Int#, minusWord#, timesWord#)
import GHC.Num.BigNat (bigNatSize#)
import GHC.Num.Integer (Integer(..), integerLog2#, integerShiftR#, integerShiftL#)
#endif
-- | Integer square root, @floor (sqrt n)@.
--
-- A floating-point estimate from 'appSqrt' is used as the starting point for
-- the integer Heron iteration in 'heron'; zero is answered directly because
-- 'heron' divides by its iterates.
--
-- The argument is expected to be non-negative; that is not checked here.
{-# SPECIALISE isqrtA :: Integer -> Integer #-}
isqrtA :: Integral a => a -> a
isqrtA n
    | n == 0    = 0
    | otherwise = heron n (fromInteger (appSqrt (fromIntegral n)))
-- | Integer variant of Heron's (Newton's) method: refines a positive initial
-- guess @a@ to @floor (sqrt n)@.
--
-- One step is taken up front so that the iterate is not below the integer
-- root. From there each further step is applied only while it makes the
-- iterate strictly smaller; the first step that does not go down marks the
-- result.
--
-- Assumes @n > 0@ and @a > 0@: iterates are used as divisors, and with
-- @n = 0@ the iteration would eventually divide by zero.
{-# SPECIALISE heron :: Integer -> Integer -> Integer #-}
heron :: Integral a => a -> a -> a
heron n a = descend (improve a)
  where
    -- one Heron step: the truncated mean of x and n/x
    improve x = (x + n `quot` x) `quot` 2

    -- keep stepping while the iterate strictly decreases
    descend x =
        let x' = improve x
        in  if x' < x then descend x' else x
-- | A fairly good approximation to the square root of a non-negative
-- 'Integer', meant as a starting point for an exact integer refinement.
--
-- * Small ('IS') arguments: one hardware 'Double' square root, truncated.
-- * Big ('IP') arguments occupying fewer than @thresh#@ machine words:
--   converted to 'Double' as a whole, square-rooted and floored.
-- * Longer arguments: shifted right by @2h@ bits with
--   @h = integerLog2 n / 2 - 47@, so that roughly the leading 96 bits remain.
--   That value is square-rooted as a 'Double', and the result is shifted left
--   by @h@ bits.
--
-- Negative arguments (small or big) raise
-- @error "integerSquareRoot': negative argument"@.
appSqrt :: Integer -> Integer
appSqrt (IS i#)
    -- A negative small Integer must not reach sqrtDouble#: the NaN it
    -- produces, converted by double2Int#, is a meaningless Int. When this
    -- guard fails, matching falls through to the error equation below.
    | not (isTrue# (i# <# 0#)) = IS (double2Int# (sqrtDouble# (int2Double# i#)))
appSqrt n@(IP bn#)
    | isTrue# ((bigNatSize# bn#) <# thresh#) =
          floor (sqrt $ fromInteger n :: Double)
    | otherwise = case integerLog2# n of
#ifdef MIN_VERSION_integer_gmp
                    l# -> case uncheckedIShiftRA# l# 1# -# 47# of
                            h# -> case shiftRInteger n (2# *# h#) of
                                    m -> case floor (sqrt $ fromInteger m :: Double) of
                                            r -> shiftLInteger r h#
#else
                    l# -> case uncheckedShiftRL# l# 1# `minusWord#` 47## of
                            h# -> case integerShiftR# n (2## `timesWord#` h#) of
                                    m -> case floor (sqrt $ fromInteger m :: Double) of
                                            r -> integerShiftL# r h#
#endif
    where
        -- Size limit (in machine words) for the direct Double conversion:
        -- 5 words on 64-bit platforms, 9 words otherwise. Larger values are
        -- shifted down first.
        thresh# :: Int#
        thresh# = if finiteBitSize (0 :: Word) == 64 then 5# else 9#
-- Negative arguments: big negatives land here directly, small negatives via
-- the failed guard on the IS equation.
appSqrt _ = error "integerSquareRoot': negative argument"
-- | Integer square root with remainder: returns @(s, r)@ with
-- @s = floor (sqrt n)@ and @r = n - s*s@.
--
-- When the base-2 logarithm of @n@ is below 2300, 'isqrtA' is used directly.
-- Otherwise this is the recursive Karatsuba square root of P. Zimmermann
-- ("Karatsuba Square Root", INRIA research report RR-3805, 1999): @n@ is cut
-- into four @k@-bit pieces by 'karatsubaSplit', and the pieces are
-- recombined by 'karatsubaStep'.
--
-- The argument is expected to be non-negative; that is not checked here.
karatsubaSqrt :: Integer -> (Integer, Integer)
karatsubaSqrt 0 = (0, 0)
karatsubaSqrt n
    | lgN < 2300     = viaHeron
    | lgN .&. 2 /= 0 = karatsubaStep k (karatsubaSplit k n)
    -- Bit 1 of lgN is clear, so the top quarter of n would be too small for
    -- the Karatsuba step. 4n has bit 1 of its logarithm set and the same k,
    -- so take the root of 4n and undo the scaling afterwards.
    | otherwise      = halve (karatsubaStep k (karatsubaSplit k (n `unsafeShiftL` 2)))
  where
    viaHeron :: (Integer, Integer)
    viaHeron = let s = isqrtA n in (s, n - s * s)

    -- Given sqrtRem (4n) = (s', r'), the root of n is s' `div` 2. If s' is
    -- odd, the 2s' - 1 lost by halving the root is added back to the
    -- remainder before it is divided by 4. The lazy pattern matches the
    -- lazy let-binding this replaces.
    halve :: (Integer, Integer) -> (Integer, Integer)
    halve ~(s', r') = (s' `unsafeShiftR` 1, adjusted `unsafeShiftR` 2)
      where
        adjusted
          | s' .&. 1 == 0 = r'
          | otherwise     = r' + double s' - 1

    -- width of each of the four pieces, in bits
    k :: Int
    k = lgN `unsafeShiftR` 2 + 1

    lgN :: Int
#ifdef MIN_VERSION_integer_gmp
    lgN = I# (integerLog2# n)
#else
    lgN = I# (word2Int# (integerLog2# n))
#endif
-- | One recombination step of the Karatsuba square root.
--
-- The argument is @n@ cut into four @k@-bit pieces @(a3, a2, a1, a0)@, most
-- significant first, as produced by 'karatsubaSplit'. Returns @(s, r)@ with
-- @r = n - s*s@:
--
-- * the root of the upper half is computed recursively,
-- * the lower part of the root comes from a division, and
-- * a negative tentative remainder is fixed by decrementing the root, which
--   adds @2s - 1@ to the remainder.
--
-- NOTE(review): the step relies on the top piece being large enough (at
-- least 2^k/4). 'karatsubaSqrt' scales its input to ensure this.
karatsubaStep :: Int -> (Integer, Integer, Integer, Integer) -> (Integer, Integer)
karatsubaStep k (a3, a2, a1, a0)
    | rem0 < 0  = (root0 - 1, rem0 + double root0 - 1)
    | otherwise = (root0, rem0)
  where
    -- Place hi above a k-bit low part; lo is expected to fit in k bits.
    glue hi lo = hi `unsafeShiftL` k .|. lo
    {-# INLINE glue #-}

    (rootHi, remHi) = karatsubaSqrt (glue a3 a2)
    (q, u)          = glue remHi a1 `quotRem` double rootHi
    root0           = rootHi `unsafeShiftL` k + q
    rem0            = glue u a0 - q * q
-- | Cut @n0@ into @(a3, a2, a1, a0)@, where @a0@, @a1@ and @a2@ are the
-- successive @k@-bit pieces taken from the low end, and @a3@ is everything
-- above bit @3k@ (not masked).
karatsubaSplit :: Int -> Integer -> (Integer, Integer, Integer, Integer)
karatsubaSplit k n0 = (n3, low n2, low n1, low n0)
  where
    -- keep only the lowest k bits
    low :: Integer -> Integer
    low x = x .&. (1 `unsafeShiftL` k - 1)

    n1, n2, n3 :: Integer
    n1 = n0 `unsafeShiftR` k
    n2 = n1 `unsafeShiftR` k
    n3 = n2 `unsafeShiftR` k
-- | Multiply an 'Integer' by two using a one-bit left shift.
double :: Integer -> Integer
double = (`unsafeShiftL` 1)
{-# INLINE double #-}