-- |
-- Module:      Math.NumberTheory.Roots.Squares.Internal
-- Copyright:   (c) 2011 Daniel Fischer, 2016-2020 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Internal functions dealing with square roots. End-users should not import this module.

{-# LANGUAGE BangPatterns     #-}
{-# LANGUAGE CPP              #-}
{-# LANGUAGE MagicHash        #-}

module Math.NumberTheory.Roots.Squares.Internal
  ( karatsubaSqrt
  , isqrtA
  ) where

import Data.Bits (finiteBitSize, unsafeShiftL, unsafeShiftR, (.&.), (.|.))

import GHC.Exts (Int(..), Int#, isTrue#, int2Double#, sqrtDouble#, double2Int#, (<#))
#ifdef MIN_VERSION_integer_gmp
import GHC.Exts (uncheckedIShiftRA#, (*#), (-#))
import GHC.Integer.GMP.Internals (Integer(..), shiftLInteger, shiftRInteger, sizeofBigNat#)
import GHC.Integer.Logarithms (integerLog2#)
#define IS S#
#define IP Jp#
#define bigNatSize sizeofBigNat
#else
import GHC.Exts (uncheckedShiftRL#, word2Int#, minusWord#, timesWord#)
import GHC.Num.BigNat (bigNatSize#)
import GHC.Num.Integer (Integer(..), integerLog2#, integerShiftR#, integerShiftL#)
#endif

-- Integer square root for any 'Integral' type.
--
-- A floating-point estimate is obtained from 'appSqrt' (which works on
-- 'Integer'), then refined to the exact integer square root by the integer
-- variant of Heron's method in 'heron'. Because the estimate is already
-- close, only a handful of refinement steps are needed unless the input is
-- really large.
--
-- Zero is answered up front: its estimate would be 0, and 'heron' divides
-- by the current estimate.
{-# SPECIALISE isqrtA :: Integer -> Integer #-}
isqrtA :: Integral a => a -> a
isqrtA n
    | n == 0    = 0
    | otherwise = heron n seed
  where
    -- initial guess, round-tripped through Integer so appSqrt can be used
    seed = fromInteger (appSqrt (fromIntegral n))

-- Integer variant of Heron's method: refine the guess @a@ towards the
-- integer square root @r@ of @n@.
--
-- One step is taken unconditionally first so that the value being worked
-- on is @>= r@; from then on @k == r@ exactly when @k <= step k@, so the
-- iteration stops as soon as a step no longer makes the value smaller.
--
-- NOTE(review): assumes a > 0 and n >= 0; a guess of 0 would divide by zero.
{-# SPECIALISE heron :: Integer -> Integer -> Integer #-}
heron :: Integral a => a -> a -> a
heron n a = descend (improve a)
  where
    -- one Heron step: average of x and n / x (integer division)
    improve x = (x + n `quot` x) `quot` 2
    -- keep stepping while the value strictly decreases
    descend x =
        let next = improve x
        in  if next < x then descend next else x

-- Find a fairly good approximation to the square root.
-- At most one off for small Integers, about 48 bits should be correct
-- for large Integers.
--
-- Three regimes:
--
--   * small (machine-sized) non-negative values: one hardware sqrt on Double#;
--   * big values occupying fewer than 'thresh#' machine words (at most
--     256 bits): convert to 'Double' directly;
--   * larger big values: shift right by an even amount 2h so that roughly
--     95 significant bits remain, take the 'Double' sqrt of that, and shift
--     the result back left by h.
--
-- Negative arguments are rejected with an error, whether they are stored
-- inline as a machine Int or as a big number.
appSqrt :: Integer -> Integer
appSqrt (IS i#)
    -- sqrtDouble# of a negative number is NaN, and double2Int# of NaN is not
    -- a meaningful value, so small negatives must be rejected here instead of
    -- silently producing a bogus approximation.
    | isTrue# (i# <# 0#) = error "integerSquareRoot': negative argument"
    | otherwise          = IS (double2Int# (sqrtDouble# (int2Double# i#)))
appSqrt n@(IP bn#)
    | isTrue# ((bigNatSize# bn#) <# thresh#) =
          floor (sqrt $ fromInteger n :: Double)
    | otherwise = case integerLog2# n of
#ifdef MIN_VERSION_integer_gmp
                    l# -> case uncheckedIShiftRA# l# 1# -# 47# of
                            h# -> case shiftRInteger n (2# *# h#) of
                                    m -> case floor (sqrt $ fromInteger m :: Double) of
                                            r -> shiftLInteger r h#
#else
                    l# -> case uncheckedShiftRL# l# 1# `minusWord#` 47## of
                            h# -> case integerShiftR# n (2## `timesWord#` h#) of
                                    m -> case floor (sqrt $ fromInteger m :: Double) of
                                            r -> integerShiftL# r h#
#endif
    where
        -- threshold for shifting vs. direct fromInteger
        -- we shift when we expect more than 256 bits
        thresh# :: Int#
        thresh# = if finiteBitSize (0 :: Word) == 64 then 5# else 9#
-- Large negative Integers end up here. There's already a check for negative
-- in integerSquareRoot, but integerSquareRoot' is exported directly too.
appSqrt _ = error "integerSquareRoot': negative argument"


-- Integer square root with remainder, using the Karatsuba Square Root
-- algorithm from
-- Paul Zimmermann. Karatsuba Square Root. [Research Report] RR-3805, 1999,
-- pp.8. <inria-00072854>
--
-- Returns (s, r) with s the integer square root of n and r = n - s * s.
-- Inputs below 2^2300 are handled by 'isqrtA'; larger ones are split into
-- four k-bit pieces and handed to 'karatsubaStep', which recurses back here
-- on the upper half.
karatsubaSqrt :: Integer -> (Integer, Integer)
karatsubaSqrt 0 = (0, 0)
karatsubaSqrt n
    | lgN < 2300     = direct
    | lgN .&. 2 /= 0 = karatsubaStep k (karatsubaSplit k n)
    | otherwise      = rescaled
  where
    -- Small inputs: Heron's method is fast enough; remainder by subtraction.
    direct = let root = isqrtA n in (root, n - root * root)

    -- Before splitting n into 4 parts, the top part must be at least 2^k/4.
    -- When bit 1 of lgN is clear that does not hold, so work on 4 * n
    -- instead and undo the scaling afterwards: the root of n is s4 / 2, and
    -- when s4 is odd the remainder must absorb an extra 2 * s4 - 1 before
    -- being divided by 4.
    rescaled =
        let (s4, r4) = karatsubaStep k (karatsubaSplit k (n `unsafeShiftL` 2))
            r4'      = if s4 .&. 1 == 0 then r4 else r4 + double s4 - 1
        in  (s4 `unsafeShiftR` 1, r4' `unsafeShiftR` 2)

    -- width in bits of each of the four pieces
    k = (lgN `unsafeShiftR` 2) + 1

    -- base-2 logarithm of n as reported by integerLog2#
#ifdef MIN_VERSION_integer_gmp
    lgN = I# (integerLog2# n)
#else
    lgN = I# (word2Int# (integerLog2# n))
#endif

-- One step of Zimmermann's Karatsuba square root.
--
-- The input is n split into four k-bit pieces (a3, a2, a1, a0), i.e.
-- n = ((a3 * 2^k + a2) * 2^k + a1) * 2^k + a0. The root and remainder of the
-- upper half a3 * 2^k + a2 are computed recursively with 'karatsubaSqrt'.
-- The low part q of the root is then found with a single division. If the
-- tentative remainder comes out negative, the root is decreased by one and
-- the remainder corrected accordingly.
karatsubaStep :: Int -> (Integer, Integer, Integer, Integer) -> (Integer, Integer)
karatsubaStep k (a3, a2, a1, a0) =
    if rTry < 0
        then (sTry - 1, rTry + double sTry - 1)
        else (sTry, rTry)
  where
    -- hi * 2^k combined with the k low bits lo
    glue hi lo = (hi `unsafeShiftL` k) .|. lo
    {-# INLINE glue #-}

    (sHi, rHi) = karatsubaSqrt (glue a3 a2)
    (q, u)     = glue rHi a1 `quotRem` double sHi
    sTry       = (sHi `unsafeShiftL` k) + q
    rTry       = glue u a0 - q * q

-- Split n0 into four pieces (a3, a2, a1, a0) such that
-- n0 = ((a3 * 2^k + a2) * 2^k + a1) * 2^k + a0.
-- a0, a1 and a2 are the three lowest k-bit chunks; a3 keeps every remaining
-- high bit (it is not masked).
karatsubaSplit :: Int -> Integer -> (Integer, Integer, Integer, Integer)
karatsubaSplit k n0 = (top, mid2, mid1, low)
  where
    -- mask selecting the lowest k bits
    mask = (1 `unsafeShiftL` k) - 1
    -- peel the lowest k-bit chunk off x: (rest, chunk)
    chop x = (x `unsafeShiftR` k, x .&. mask)
    (rest1, low)  = chop n0
    (rest2, mid1) = chop rest1
    (top,   mid2) = chop rest2

-- Multiply an 'Integer' by two with a one-bit left shift.
double :: Integer -> Integer
double = (`unsafeShiftL` 1)
{-# INLINE double #-}