diff --git a/lib/hcrypto/imath/imath.c b/lib/hcrypto/imath/imath.c index 7063c4c8d..920e6ed26 100755 --- a/lib/hcrypto/imath/imath.c +++ b/lib/hcrypto/imath/imath.c @@ -725,10 +725,11 @@ mp_result mp_int_mul(mp_int a, mp_int b, mp_int c) /* Output is positive if inputs have same sign, otherwise negative */ osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG; - /* If the output is not equal to any of the inputs, we'll write the - results there directly; otherwise, allocate a temporary space. */ + /* If the output is not identical to any of the inputs, we'll write + the results directly; otherwise, allocate a temporary space. */ ua = MP_USED(a); ub = MP_USED(b); - osize = ua + ub; + osize = MAX(ua, ub); + osize = 4 * ((osize + 1) / 2); if(c == a || c == b) { p = ROUND_PREC(osize); @@ -809,7 +810,7 @@ mp_result mp_int_sqr(mp_int a, mp_int c) CHECK(a != NULL && c != NULL); /* Get a temporary buffer big enough to hold the result */ - osize = (mp_size) 2 * MP_USED(a); + osize = (mp_size) 4 * ((MP_USED(a) + 1) / 2); if(a == c) { p = ROUND_PREC(osize); p = MAX(p, default_precision); @@ -2308,26 +2309,26 @@ static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, product; twice the space is plenty. */ if((t1 = s_alloc(4 * buf_size)) == NULL) return 0; - t2 = t1 + buf_size; + t2 = t1 + buf_size; t3 = t2 + buf_size; ZERO(t1, 4 * buf_size); /* t1 and t2 are initially used as temporaries to compute the inner product (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0 */ - carry = s_uadd(da, a_top, t1, bot_size, at_size); /* t1 = a1 + a0 */ + carry = s_uadd(da, a_top, t1, bot_size, at_size); /* t1 = a1 + a0 */ t1[bot_size] = carry; - carry = s_uadd(db, b_top, t2, bot_size, bt_size); /* t2 = b1 + b0 */ + carry = s_uadd(db, b_top, t2, bot_size, bt_size); /* t2 = b1 + b0 */ t2[bot_size] = carry; - (void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1); /* t3 = t1 * t2 */ + (void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1); /* t3 = t1 * t2 */ /* Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so that we're left with only the pieces we want: t3 = a1b0 + a0b1 */ - ZERO(t1, bot_size + 1); - ZERO(t2, bot_size + 1); + ZERO(t1, buf_size); + ZERO(t2, buf_size); (void) s_kmul(da, db, t1, bot_size, bot_size); /* t1 = a0 * b0 */ (void) s_kmul(a_top, b_top, t2, at_size, bt_size); /* t2 = a1 * b1 */ @@ -2337,11 +2338,13 @@ static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, /* Assemble the output value */ COPY(t1, dc, buf_size); - (void) s_uadd(t3, dc + bot_size, dc + bot_size, - buf_size + 1, buf_size + 1); + carry = s_uadd(t3, dc + bot_size, dc + bot_size, + buf_size + 1, buf_size); + assert(carry == 0); - (void) s_uadd(t2, dc + 2*bot_size, dc + 2*bot_size, - buf_size, buf_size); + carry = s_uadd(t2, dc + 2*bot_size, dc + 2*bot_size, + buf_size, buf_size); + assert(carry == 0); s_free(t1); /* note t2 and t3 are just internal pointers to t1 */ } @@ -2390,7 +2393,7 @@ static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a) if(multiply_threshold && size_a > multiply_threshold) { mp_size bot_size = (size_a + 1) / 2; mp_digit *a_top = da + bot_size; - mp_digit *t1, *t2, *t3; + mp_digit *t1, *t2, *t3, carry; mp_size at_size = size_a - bot_size; mp_size buf_size = 2 * bot_size; @@ -2420,11 +2423,13 @@ static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a) /* Assemble the output value */ COPY(t1, dc, 2 * bot_size); - (void) s_uadd(t3, dc + bot_size, dc + bot_size, - buf_size + 1, buf_size + 1); - - (void) s_uadd(t2, dc + 2*bot_size, dc + 2*bot_size, - buf_size, buf_size); + carry = s_uadd(t3, dc + bot_size, dc + bot_size, + buf_size + 1, buf_size); + assert(carry == 0); + + carry = s_uadd(t2, dc + 2*bot_size, dc + 2*bot_size, + buf_size, buf_size); + assert(carry == 0); s_free(t1); /* note that t2 and t2 are internal pointers only */