001/**
002 * The MIT License (MIT)
003 *
004 * Copyright (c) 2015-2016 decimal4j (tools4j), Marco Terzer
005 *
006 * Permission is hereby granted, free of charge, to any person obtaining a copy
007 * of this software and associated documentation files (the "Software"), to deal
008 * in the Software without restriction, including without limitation the rights
009 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
010 * copies of the Software, and to permit persons to whom the Software is
011 * furnished to do so, subject to the following conditions:
012 *
013 * The above copyright notice and this permission notice shall be included in all
014 * copies or substantial portions of the Software.
015 *
016 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
017 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
018 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
019 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
020 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
021 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
022 * SOFTWARE.
023 */
024package org.decimal4j.arithmetic;
025
026import org.decimal4j.api.DecimalArithmetic;
027import org.decimal4j.scale.ScaleMetrics;
028import org.decimal4j.truncate.DecimalRounding;
029import org.decimal4j.truncate.TruncatedPart;
030
031/**
032 * Provides static methods to calculate square roots of Decimal numbers.
033 */
034final class Sqrt {
035        /**
036         * This mask is used to obtain the value of an int as if it were unsigned.
037         */
038        private static final long LONG_MASK = 0xffffffffL;
039
040        /**
041         * Calculates the square root of the specified long value truncating the
042         * result if necessary.
043         * 
044         * @param lValue
045         *            the long value
046         * @return <tt>round<sub>DOWN</sub>(lValue)</tt>
047         * @throws ArithmeticException
048         *             if {@code lValue < 0}
049         */
050        public static final long sqrtLong(long lValue) {
051                if (lValue < 0) {
052                        throw new ArithmeticException("Square root of a negative value: " + lValue);
053                }
054                // http://www.codecodex.com/wiki/Calculate_an_integer_square_root
055                if ((lValue & 0xfff0000000000000L) == 0) {
056                        return (long) StrictMath.sqrt(lValue);
057                }
058                final long result = (long) StrictMath.sqrt(2.0d * (lValue >>> 1));
059                return result * result - lValue > 0L ? result - 1 : result;
060        }
061
062        /**
063         * Calculates the square root of the specified long value rounding the
064         * result if necessary.
065         * 
066         * @param rounding
067         *            the rounding to apply if necessary
068         * @param lValue
069         *            the long value
070         * @return <tt>round(lValue)</tt>
071         * @throws ArithmeticException
072         *             if {@code lValue < 0}
073         */
074        public static final long sqrtLong(DecimalRounding rounding, long lValue) {
075                if (lValue < 0) {
076                        throw new ArithmeticException("Square root of a negative value: " + lValue);
077                }
078                // square root
079                // @see
080                // http://www.embedded.com/electronics-blogs/programmer-s-toolbox/4219659/Integer-Square-Roots
081                long rem = 0;
082                long root = 0;
083                final int zerosHalf = Long.numberOfLeadingZeros(lValue) >> 1;
084                long scaled = lValue << (zerosHalf << 1);
085                for (int i = zerosHalf; i < 32; i++) {
086                        root <<= 1;
087                        rem = ((rem << 2) + (scaled >>> 62));
088                        scaled <<= 2;
089                        root++;
090                        if (root <= rem) {
091                                rem -= root;
092                                root++;
093                        } else {
094                                root--;
095                        }
096                }
097                final long truncated = root >>> 1;
098                if (rem == 0 | rounding == DecimalRounding.DOWN | rounding == DecimalRounding.FLOOR) {
099                        return truncated;
100                }
101                return truncated + getRoundingIncrement(rounding, truncated, rem);
102        }
103
104        /**
105         * Calculates the square root of the specified unscaled decimal value
106         * truncating the result if necessary.
107         * 
108         * @param arith
109         *            the arithmetic associated with the value
110         * @param uDecimal
111         *            the unscaled decimal value
112         * @return <tt>round<sub>DOWN</sub>(uDecimal)</tt>
113         * @throws ArithmeticException
114         *             if {@code uDecimal < 0}
115         */
116        public static final long sqrt(DecimalArithmetic arith, long uDecimal) {
117                return sqrt(arith, DecimalRounding.DOWN, uDecimal);
118        }
119
120        /**
121         * Calculates the square root of the specified unscaled decimal value
122         * rounding the result if necessary.
123         * 
124         * @param arith
125         *            the arithmetic associated with the value
126         * @param rounding
127         *            the rounding to apply if necessary
128         * @param uDecimal
129         *            the unscaled decimal value
130         * @return <tt>round(uDecimal)</tt>
131         * @throws ArithmeticException
132         *             if {@code uDecimal < 0}
133         */
134        public static final long sqrt(DecimalArithmetic arith, DecimalRounding rounding, long uDecimal) {
135                if (uDecimal < 0) {
136                        throw new ArithmeticException("Square root of a negative value: " + arith.toString(uDecimal));
137                }
138                final ScaleMetrics scaleMetrics = arith.getScaleMetrics();
139
140                // multiply by scale factor into a 128bit integer
141                final int lFactor = (int) (uDecimal & LONG_MASK);
142                final int hFactor = (int) (uDecimal >>> 32);
143                long lScaled;
144                long hScaled;
145                long product;
146
147                product = scaleMetrics.mulloByScaleFactor(lFactor);
148                lScaled = product & LONG_MASK;
149                product = scaleMetrics.mulhiByScaleFactor(lFactor) + (product >>> 32);
150                hScaled = product >>> 32;
151                product = scaleMetrics.mulloByScaleFactor(hFactor) + (product & LONG_MASK);
152                lScaled |= ((product & LONG_MASK) << 32);
153                hScaled = scaleMetrics.mulhiByScaleFactor(hFactor) + hScaled + (product >>> 32);
154
155                // square root
156                // @see
157                // http://www.embedded.com/electronics-blogs/programmer-s-toolbox/4219659/Integer-Square-Roots
158                int zerosHalf;
159                long rem = 0;
160                long root = 0;
161
162                // iteration for high 32 bits
163                zerosHalf = Long.numberOfLeadingZeros(hScaled) >> 1;
164                hScaled <<= (zerosHalf << 1);
165                for (int i = zerosHalf; i < 32; i++) {
166                        root <<= 1;
167                        rem = ((rem << 2) + (hScaled >>> 62));
168                        hScaled <<= 2;
169                        root++;
170                        if (root <= rem) {
171                                rem -= root;
172                                root++;
173                        } else {
174                                root--;
175                        }
176                }
177
178                // iteration for low 32 bits (last iteration below)
179                zerosHalf = zerosHalf == 32 ? Long.numberOfLeadingZeros(lScaled) >> 1 : 0;
180                lScaled <<= (zerosHalf << 1);
181                for (int i = zerosHalf; i < 31; i++) {
182                        root <<= 1;
183                        rem = ((rem << 2) + (lScaled >>> 62));
184                        lScaled <<= 2;
185                        root++;
186                        if (root <= rem) {
187                                rem -= root;
188                                root++;
189                        } else {
190                                root--;
191                        }
192                }
193
194                // last iteration needs unsigned compare
195                root <<= 1;
196                rem = ((rem << 2) + (lScaled >>> 62));
197                lScaled <<= 2;
198                root++;
199                if (Unsigned.isLessOrEqual(root, rem)) {
200                        rem -= root;
201                        root++;
202                } else {
203                        root--;
204                }
205
206                // round result if necessary
207                final long truncated = root >>> 1;
208                if (rem == 0 | rounding == DecimalRounding.DOWN | rounding == DecimalRounding.FLOOR) {
209                        return truncated;
210                }
211                return truncated + getRoundingIncrement(rounding, truncated, rem);
212        }
213
214        // PRECONDITION: rem != 0
215        // NOTE: TruncatedPart cannot be 0.5 because this would square to 0.25
216        private static final int getRoundingIncrement(DecimalRounding rounding, long truncated, long rem) {
217                if (truncated < rem) {
218                        return rounding.calculateRoundingIncrement(1, truncated, TruncatedPart.GREATER_THAN_HALF);
219                }
220                return rounding.calculateRoundingIncrement(1, truncated, TruncatedPart.LESS_THAN_HALF_BUT_NOT_ZERO);
221        }
222
223        // no instances
224        private Sqrt() {
225                super();
226        }
227}