001/**
002 * The MIT License (MIT)
003 *
004 * Copyright (c) 2015-2016 decimal4j (tools4j), Marco Terzer
005 *
006 * Permission is hereby granted, free of charge, to any person obtaining a copy
007 * of this software and associated documentation files (the "Software"), to deal
008 * in the Software without restriction, including without limitation the rights
009 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
010 * copies of the Software, and to permit persons to whom the Software is
011 * furnished to do so, subject to the following conditions:
012 *
013 * The above copyright notice and this permission notice shall be included in all
014 * copies or substantial portions of the Software.
015 *
016 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
017 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
018 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
019 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
020 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
021 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
022 * SOFTWARE.
023 */
024package org.decimal4j.arithmetic;
025
026import org.decimal4j.api.DecimalArithmetic;
027import org.decimal4j.scale.ScaleMetrics;
028import org.decimal4j.truncate.DecimalRounding;
029import org.decimal4j.truncate.OverflowMode;
030import org.decimal4j.truncate.TruncatedPart;
031
032/**
033 * Contains static methods to calculate powers of a Decimal number.
034 */
035final class Pow {
036
037        /**
038         * Constant for {@code floor(sqrt(Long.MAX_VALUE))}
039         */
040        private static final long FLOOR_SQRT_MAX_LONG = 3037000499L;
041
042        private static final void checkExponent(int exponent) {
043                if (exponent < -999999999 || exponent > 999999999) {
044                        throw new IllegalArgumentException("Exponent must be in [-999999999,999999999] but was: " + exponent);
045                }
046        }
047
048        /**
049         * Calculates the power <tt>(lBase<sup>exponent</sup>)</tt>. Overflows are
050         * silently ignored.
051         * 
052         * @param arith
053         *            the arithmetic associated with {@code lBase}
054         * @param rounding
055         *            the rounding to apply if rounding is necessary for negative
056         *            exponents
057         * @param lBase
058         *            the unscaled decimal base value
059         * @param exponent
060         *            the exponent
061         * @return <tt>round(lBase<sup>exponent</sup>)</tt>
062         * @throws ArithmeticException
063         *             if {@code lBase==0} and the exponent is negative or if
064         *             {@code roundingMode==UNNECESSARY} and rounding is necessary
065         */
066        public static final long powLong(DecimalArithmetic arith, DecimalRounding rounding, long lBase, int exponent) {
067                checkExponent(exponent);
068                final SpecialPowResult special = SpecialPowResult.getFor(arith, lBase, exponent);
069                if (special != null) {
070                        return special.pow(arith, lBase, exponent);
071                }
072                return powLong(rounding, lBase, exponent);
073        }
074
075        private static final long powLong(DecimalRounding rounding, long lBase, int exponent) {
076                if (exponent >= 0) {
077                        return powLongWithPositiveExponent(lBase, exponent);
078                } else {
079                        // result is 1/powered
080                        // we have dealt with special cases above hence powered is neither
081                        // of 0, 1, -1
082                        // and everything else can't be 0.5 because sqrt_i(0.5) is not real
083                        final int sgn = lBase > 0 | (exponent & 0x1) == 0 ? 1 : -1;// lBase
084                                                                                                                                                // cannot
085                                                                                                                                                // be 0
086                        return rounding.calculateRoundingIncrement(sgn, 0, TruncatedPart.LESS_THAN_HALF_BUT_NOT_ZERO);
087                }
088        }
089
090        /**
091         * Calculates the power <tt>(lBase<sup>exponent</sup>)</tt>. An exception is
092         * thrown if an overflow occurs.
093         * 
094         * @param arith
095         *            the arithmetic associated with {@code lBase}
096         * @param rounding
097         *            the rounding to apply if rounding is necessary for negative
098         *            exponents
099         * @param lBase
100         *            the unscaled decimal base value
101         * @param exponent
102         *            the exponent
103         * @return <tt>round(lBase<sup>exponent</sup>)</tt>
104         * @throws ArithmeticException
105         *             if {@code lBase==0} and the exponent is negative, if
106         *             {@code roundingMode==UNNECESSARY} and rounding is necessary
107         *             or if an overflow occurs and the arithmetic's
108         *             {@link OverflowMode} is set to throw an exception
109         */
110        public static final long powLongChecked(DecimalArithmetic arith, DecimalRounding rounding, long lBase, int exponent) {
111                checkExponent(exponent);
112                final SpecialPowResult special = SpecialPowResult.getFor(arith, lBase, exponent);
113                if (special != null) {
114                        return special.pow(arith, lBase, exponent);
115                }
116                return powLongChecked(rounding, lBase, exponent);
117        }
118
119        private static final long powLongChecked(DecimalRounding rounding, long lBase, int exponent) {
120                if (exponent >= 0) {
121                        return powLongCheckedWithPositiveExponent(lBase, exponent);
122                } else {
123                        // result is 1/powered
124                        // we have dealt with special cases above hence powered is neither
125                        // of 0, 1, -1
126                        // and everything else can't be 0.5 because sqrt_i(0.5) is not real
127                        final int sgn = lBase > 0 | (exponent & 0x1) == 0 ? 1 : -1;// lBase
128                                                                                                                                                // cannot
129                                                                                                                                                // be 0
130                        return rounding.calculateRoundingIncrement(sgn, 0, TruncatedPart.LESS_THAN_HALF_BUT_NOT_ZERO);
131                }
132        }
133
134        private static final long powLongCheckedOrUnchecked(OverflowMode overflowMode, DecimalRounding rounding, long longBase, int exponent) {
135                return overflowMode == OverflowMode.UNCHECKED ? powLong(rounding, longBase, exponent)
136                                : powLongChecked(rounding, longBase, exponent);
137        }
138
139        /**
140         * Power function for checked or unchecked arithmetic. The result is within
141         * 1 ULP for positive exponents.
142         * 
143         * @param arith
144         *            the arithmetic
145         * @param rounding
146         *            the rounding to apply
147         * @param uDecimalBase
148         *            the unscaled base
149         * @param exponent
150         *            the exponent
151         * @return {@code uDecimalbase ^ exponent}
152         */
153        public static final long pow(DecimalArithmetic arith, DecimalRounding rounding, long uDecimalBase, int exponent) {
154                checkExponent(exponent);
155                final SpecialPowResult special = SpecialPowResult.getFor(arith, uDecimalBase, exponent);
156                if (special != null) {
157                        return special.pow(arith, uDecimalBase, exponent);
158                }
159
160                // some other special cases
161                final ScaleMetrics scaleMetrics = arith.getScaleMetrics();
162
163                final long intVal = scaleMetrics.divideByScaleFactor(uDecimalBase);
164                final long fraVal = uDecimalBase - scaleMetrics.multiplyByScaleFactor(intVal);
165                if (exponent >= 0 & fraVal == 0) {
166                        // integer
167                        final long result = powLongCheckedOrUnchecked(arith.getOverflowMode(), rounding, intVal, exponent);
168                        return longToUnscaledCheckedOrUnchecekd(arith, uDecimalBase, exponent, result);
169                }
170                if (exponent < 0 & intVal == 0) {
171                        final long one = scaleMetrics.getScaleFactor();
172                        if ((one % fraVal) == 0) {
173                                // inverted value is an integer
174                                final long result = powLongCheckedOrUnchecked(arith.getOverflowMode(), rounding, one / fraVal,
175                                                -exponent);
176                                return longToUnscaledCheckedOrUnchecekd(arith, uDecimalBase, exponent, result);
177                        }
178                }
179                try {
180                        return powWithPrecision18(arith, rounding, intVal, fraVal, exponent);
181                } catch (IllegalArgumentException e) {
182                        throw new ArithmeticException("Overflow: " + arith.toString(uDecimalBase) + "^" + exponent);
183                }
184        }
185
186        // PRECONDITION: n != 0 and n in [-999999999,999999999]
187        private static final long powWithPrecision18(DecimalArithmetic arith, DecimalRounding rounding, long ival, long fval, int n) {
188                // eliminate sign
189                final int sgn = ((n & 0x1) != 0) ? Long.signum(ival | fval) : 1;
190                final long absInt = Math.abs(ival);
191                final long absFra = Math.abs(fval);
192                final DecimalRounding powRounding = n >= 0 ? rounding : RoundingInverse.RECIPROCAL.invert(rounding);
193
194                // 36 digit left hand side, initialized with base value
195                final UnsignedDecimal9i36f lhs = UnsignedDecimal9i36f.THREAD_LOCAL_1.get().init(absInt, absFra,
196                                arith.getScaleMetrics());
197
198                // 36 digit accumulator, initialized with one
199                final UnsignedDecimal9i36f acc = UnsignedDecimal9i36f.THREAD_LOCAL_2.get().initOne();
200
201                // ready to carry out power calculation...
202                int mag = Math.abs(n);
203                boolean seenbit = false; // avoid squaring ONE
204                for (int i = 1;; i++) { // for each bit [top bit ignored]
205                        mag += mag; // shift left 1 bit
206                        if (mag < 0) { // top bit is set
207                                if (seenbit) {
208                                        acc.multiply(sgn, lhs, powRounding);// acc=acc*x
209                                } else {
210                                        seenbit = true;
211                                        acc.init(lhs); // acc=x
212                                }
213                        }
214                        if (i == 31) {
215                                break; // that was the last bit
216                        }
217                        if (seenbit) {
218                                acc.multiply(sgn, acc, powRounding); // acc=acc*acc [square]
219                        }// else (!seenbit) no point in squaring ONE
220                }
221
222                if (n < 0) {
223                        return acc.getInverted(sgn, arith, rounding, powRounding);
224                }
225                return acc.getDecimal(sgn, arith, rounding);
226        }
227
228        private static final long powLongWithPositiveExponent(long lBase, int exponent) {
229                assert(exponent > 0);
230
231                long accum = 1;
232                while (true) {
233                        switch (exponent) {
234                        case 0:
235                                return accum;
236                        case 1:
237                                return accum * lBase;
238                        default:
239                                if ((exponent & 1) != 0) {
240                                        accum *= lBase;
241                                }
242                                exponent >>= 1;
243                                if (exponent > 0) {
244                                        lBase *= lBase;
245                                }
246                        }
247                }
248        }
249
250        private static final long powLongCheckedWithPositiveExponent(long lBase, int exponent) {
251                assert(exponent > 0);
252                if (lBase >= -2 & lBase <= 2) {
253                        switch ((int) lBase) {
254                        case 0:
255                                return (exponent == 0) ? 1 : 0;
256                        case 1:
257                                return 1;
258                        case (-1):
259                                return ((exponent & 1) == 0) ? 1 : -1;
260                        case 2:
261                                if (exponent >= Long.SIZE - 1) {
262                                        throw new ArithmeticException("Overflow: " + lBase + "^" + exponent);
263                                }
264                                return 1L << exponent;
265                        case (-2):
266                                if (exponent >= Long.SIZE) {
267                                        throw new ArithmeticException("Overflow: " + lBase + "^" + exponent);
268                                }
269                                return ((exponent & 1) == 0) ? (1L << exponent) : (-1L << exponent);
270                        default:
271                                throw new AssertionError();
272                        }
273                }
274                long accum = 1;
275                while (true) {
276                        switch (exponent) {
277                        case 0:
278                                return accum;
279                        case 1:
280                                return Checked.multiplyLong(accum, lBase);
281                        default:
282                                if ((exponent & 1) != 0) {
283                                        accum = Checked.multiplyLong(accum, lBase);
284                                }
285                                exponent >>= 1;
286                                if (exponent > 0) {
287                                        if (lBase > FLOOR_SQRT_MAX_LONG | lBase < -FLOOR_SQRT_MAX_LONG) {
288                                                throw new ArithmeticException("Overflow: " + lBase + "^" + exponent);
289                                        }
290                                        lBase *= lBase;
291                                }
292                        }
293                }
294        }
295
296        private static final long longToUnscaledCheckedOrUnchecekd(DecimalArithmetic arith, long uBase, int exponent, long longResult) {
297                if (!arith.getOverflowMode().isChecked()) {
298                        return LongConversion.longToUnscaledUnchecked(arith.getScaleMetrics(), longResult);
299                }
300                try {
301                        return LongConversion.longToUnscaled(arith.getScaleMetrics(), longResult);
302                } catch (IllegalArgumentException e) {
303                        throw new ArithmeticException("Overflow: " + arith.toString(uBase) + "^" + exponent + "=" + longResult);
304                }
305        }
306
307        // no instances
308        private Pow() {
309        }
310
311}