// wireguard-vanity-key/src/cpu/x25519.rs

// Adapted from boringtun; the original copyright and license notice follows.
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
//! Elliptic-curve Diffie-Hellman exchange over Curve25519.
use crate::Key;
use std::ops::Add;
use std::ops::Mul;
use std::ops::Sub;
/// Copies the leading elements of `slice` into a freshly defaulted fixed-size array `A`.
///
/// The number of elements copied equals the length of `A` itself (e.g. 8 for `[u8; 8]`);
/// any trailing elements of `slice` are ignored.
///
/// # Panics
///
/// Panics if `slice` holds fewer elements than `A`.
#[inline(always)]
fn make_array<A, T>(slice: &[T]) -> A
where
    A: Sized + Default + AsMut<[T]> + std::borrow::Borrow<[T]>,
    T: Copy,
{
    let mut out = A::default();
    let wanted = <A as std::borrow::Borrow<[T]>>::borrow(&out).len();
    <A as AsMut<[T]>>::as_mut(&mut out).copy_from_slice(&slice[..wanted]);
    out
}
/// Keeps the low 63 bits; applied to the top limb it clears bit 255 of a field element.
const MASK_63BITS: u128 = (1 << 63) - 1;
/// Keeps the low 64 bits of a `u128` accumulator (one limb, carry discarded).
const MASK_64BITS: u128 = (1 << 64) - 1;

// Internal structs for fast arithmetic.

/// A field element modulo 2^255 - 19, stored as four little-endian 64-bit limbs.
///
/// The arithmetic in this file keeps values only partially reduced; `mod_final_25519`
/// produces the canonical representative.
#[derive(Clone, Copy)]
struct Felem([u64; 4]);

/// A 512-bit intermediate product, stored as eight little-endian 64-bit limbs.
struct Felem2([u64; 8]);
// The body is built from shifts, masks and `wrapping_*` calls, which clippy's
// `suspicious_arithmetic_impl` lint may flag inside an arithmetic trait impl. Clippy lints are
// tool lints, so a plain `allow` works without the deprecated `feature = "cargo-clippy"` guard.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Add for Felem {
    type Output = Felem;

    /// Adds two field elements modulo p = 2^255 - 19.
    ///
    /// The limbs are added with carry propagation. Everything at or above bit 255 is then
    /// folded back into the low limb using 2^255 ≡ 19 (mod p). The result is congruent to
    /// `self + other` but may still be ≥ p.
    #[inline(always)]
    fn add(self, other: Felem) -> Felem {
        let x0 = u128::from(self.0[0]);
        let x1 = u128::from(self.0[1]);
        let x2 = u128::from(self.0[2]);
        let x3 = u128::from(self.0[3]);
        let y0 = u128::from(other.0[0]);
        let y1 = u128::from(other.0[1]);
        let y2 = u128::from(other.0[2]);
        let y3 = u128::from(other.0[3]);

        // Limb-wise sum; `>> 64` carries the overflow of each limb into the next.
        let mut acc0 = x0.wrapping_add(y0);
        let mut acc1 = x1.wrapping_add(y1).wrapping_add(acc0 >> 64);
        let mut acc2 = x2.wrapping_add(y2).wrapping_add(acc1 >> 64);
        let mut acc3 = x3.wrapping_add(y3).wrapping_add(acc2 >> 64);

        // Split off the bits at and above 2^255 and fold them back in as `top * 19`.
        let mut top = (acc3 >> 63) & 0xffff_ffff_ffff_ffff;
        acc0 &= 0xffff_ffff_ffff_ffff;
        acc1 &= 0xffff_ffff_ffff_ffff;
        acc2 &= 0xffff_ffff_ffff_ffff;
        acc3 &= 0x7fff_ffff_ffff_ffff;
        top = top.wrapping_mul(19);
        acc0 = acc0.wrapping_add(top);
        acc1 = acc1.wrapping_add(acc0 >> 64);
        acc2 = acc2.wrapping_add(acc1 >> 64);
        acc3 = acc3.wrapping_add(acc2 >> 64);

        // The `as u64` casts keep the low 64 bits; the carries were propagated above.
        Felem([acc0 as u64, acc1 as u64, acc2 as u64, acc3 as u64])
    }
}
// The body is built from shifts, masks and `wrapping_*` calls, which clippy's
// `suspicious_arithmetic_impl` lint may flag inside an arithmetic trait impl. Clippy lints are
// tool lints, so a plain `allow` works without the deprecated `feature = "cargo-clippy"` guard.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Sub for Felem {
    type Output = Felem;

    /// Subtracts `other` from `self` modulo p = 2^255 - 19.
    ///
    /// Computes `4p + self - other` limb-wise so that no limb can go negative. Everything at or
    /// above bit 255 is then folded back into the low limb using 2^255 ≡ 19 (mod p). The
    /// result is congruent to `self - other` but may still be ≥ p.
    #[inline(always)]
    fn sub(self, other: Felem) -> Felem {
        // 4p = 2^257 - 76, spread over four limbs so that every limb is at least 2^65 - 76.
        // That exceeds any 64-bit limb of `other`, so the subtraction below never underflows.
        static POLY_X4: [u128; 4] = [
            0x1_ffff_ffff_ffff_ffb4,
            0x1_ffff_ffff_ffff_fffe,
            0x1_ffff_ffff_ffff_fffe,
            0x1_ffff_ffff_ffff_fffe,
        ];
        let x0 = u128::from(self.0[0]);
        let x1 = u128::from(self.0[1]);
        let x2 = u128::from(self.0[2]);
        let x3 = u128::from(self.0[3]);
        let y0 = u128::from(other.0[0]);
        let y1 = u128::from(other.0[1]);
        let y2 = u128::from(other.0[2]);
        let y3 = u128::from(other.0[3]);

        // Limb-wise 4p - other + self; `>> 64` carries the overflow of each limb into the next.
        let mut acc0 = POLY_X4[0].wrapping_sub(y0).wrapping_add(x0);
        let mut acc1 = POLY_X4[1]
            .wrapping_sub(y1)
            .wrapping_add(x1)
            .wrapping_add(acc0 >> 64);
        let mut acc2 = POLY_X4[2]
            .wrapping_sub(y2)
            .wrapping_add(x2)
            .wrapping_add(acc1 >> 64);
        let mut acc3 = POLY_X4[3]
            .wrapping_sub(y3)
            .wrapping_add(x3)
            .wrapping_add(acc2 >> 64);

        // Split off the bits at and above 2^255 and fold them back in as `top * 19`.
        let mut top = (acc3 >> 63) & 0xffff_ffff_ffff_ffff;
        acc0 &= 0xffff_ffff_ffff_ffff;
        acc1 &= 0xffff_ffff_ffff_ffff;
        acc2 &= 0xffff_ffff_ffff_ffff;
        acc3 &= 0x7fff_ffff_ffff_ffff;
        top = top.wrapping_mul(19);
        acc0 = acc0.wrapping_add(top);
        acc1 = acc1.wrapping_add(acc0 >> 64);
        acc2 = acc2.wrapping_add(acc1 >> 64);
        acc3 = acc3.wrapping_add(acc2 >> 64);

        // The `as u64` casts keep the low 64 bits; the carries were propagated above.
        Felem([acc0 as u64, acc1 as u64, acc2 as u64, acc3 as u64])
    }
}
// The body is built from shifts, masks and `wrapping_*` calls, which clippy's
// `suspicious_arithmetic_impl` lint may flag inside an arithmetic trait impl. Clippy lints are
// tool lints, so a plain `allow` works without the deprecated `feature = "cargo-clippy"` guard.
#[allow(clippy::suspicious_arithmetic_impl)]
impl Mul for Felem {
    type Output = Felem;

    /// Multiplies two field elements modulo p = 2^255 - 19.
    ///
    /// Computes the full 512-bit product with a 4x4 schoolbook multiplication, one row per
    /// limb of `other`. It then reduces the product with `mod_25519`. The result is congruent
    /// to `self * other` but is not guaranteed to be below p.
    #[inline(always)]
    fn mul(self, other: Felem) -> Felem {
        let x0 = u128::from(self.0[0]);
        let x1 = u128::from(self.0[1]);
        let x2 = u128::from(self.0[2]);
        let x3 = u128::from(self.0[3]);

        // Row 0: x * y0, written into limbs 0..=4. Each accumulator holds one 64-bit output
        // limb plus a pending carry in its upper half; `>> 64` moves the carry up and the mask
        // keeps the limb.
        let y0 = u128::from(other.0[0]);
        let mut t = x0.wrapping_mul(y0);
        let acc0 = t & 0xffff_ffff_ffff_ffff;
        let mut acc1 = t >> 64;
        t = x1.wrapping_mul(y0);
        acc1 = acc1.wrapping_add(t);
        let mut acc2 = acc1 >> 64;
        acc1 &= 0xffff_ffff_ffff_ffff;
        t = x2.wrapping_mul(y0);
        acc2 = acc2.wrapping_add(t);
        let mut acc3 = acc2 >> 64;
        acc2 &= 0xffff_ffff_ffff_ffff;
        t = x3.wrapping_mul(y0);
        acc3 = acc3.wrapping_add(t);
        let mut acc4 = acc3 >> 64;
        acc3 &= 0xffff_ffff_ffff_ffff;

        // Row 1: x * y1, accumulated into limbs 1..=5.
        let y1 = u128::from(other.0[1]);
        t = x0.wrapping_mul(y1);
        acc1 = acc1.wrapping_add(t);
        let mut top = acc1 >> 64;
        acc1 &= 0xffff_ffff_ffff_ffff;
        t = x1.wrapping_mul(y1);
        acc2 = acc2.wrapping_add(top);
        acc2 = acc2.wrapping_add(t);
        top = acc2 >> 64;
        acc2 &= 0xffff_ffff_ffff_ffff;
        t = x2.wrapping_mul(y1);
        acc3 = acc3.wrapping_add(top);
        acc3 = acc3.wrapping_add(t);
        top = acc3 >> 64;
        acc3 &= 0xffff_ffff_ffff_ffff;
        t = x3.wrapping_mul(y1);
        acc4 = acc4.wrapping_add(top);
        acc4 = acc4.wrapping_add(t);
        let mut acc5 = acc4 >> 64;
        acc4 &= 0xffff_ffff_ffff_ffff;

        // Row 2: x * y2, accumulated into limbs 2..=6.
        let y2 = u128::from(other.0[2]);
        t = x0.wrapping_mul(y2);
        acc2 = acc2.wrapping_add(t);
        top = acc2 >> 64;
        acc2 &= 0xffff_ffff_ffff_ffff;
        t = x1.wrapping_mul(y2);
        acc3 = acc3.wrapping_add(top);
        acc3 = acc3.wrapping_add(t);
        top = acc3 >> 64;
        acc3 &= 0xffff_ffff_ffff_ffff;
        t = x2.wrapping_mul(y2);
        acc4 = acc4.wrapping_add(top);
        acc4 = acc4.wrapping_add(t);
        top = acc4 >> 64;
        acc4 &= 0xffff_ffff_ffff_ffff;
        t = x3.wrapping_mul(y2);
        acc5 = acc5.wrapping_add(top);
        acc5 = acc5.wrapping_add(t);
        let mut acc6 = acc5 >> 64;
        acc5 &= 0xffff_ffff_ffff_ffff;

        // Row 3: x * y3, accumulated into limbs 3..=7.
        let y3 = u128::from(other.0[3]);
        t = x0.wrapping_mul(y3);
        acc3 = acc3.wrapping_add(t);
        top = acc3 >> 64;
        acc3 &= 0xffff_ffff_ffff_ffff;
        t = x1.wrapping_mul(y3);
        acc4 = acc4.wrapping_add(top);
        acc4 = acc4.wrapping_add(t);
        top = acc4 >> 64;
        acc4 &= 0xffff_ffff_ffff_ffff;
        t = x2.wrapping_mul(y3);
        acc5 = acc5.wrapping_add(top);
        acc5 = acc5.wrapping_add(t);
        top = acc5 >> 64;
        acc5 &= 0xffff_ffff_ffff_ffff;
        t = x3.wrapping_mul(y3);
        acc6 = acc6.wrapping_add(top);
        acc6 = acc6.wrapping_add(t);
        let acc7 = acc6 >> 64;
        acc6 &= 0xffff_ffff_ffff_ffff;

        // Reduce the 512-bit product modulo 2^255 - 19.
        mod_25519(Felem2([
            acc0 as u64,
            acc1 as u64,
            acc2 as u64,
            acc3 as u64,
            acc4 as u64,
            acc5 as u64,
            acc6 as u64,
            acc7 as u64,
        ]))
    }
}
impl Felem {
    /// Squares `self` `rep` times in a row modulo 2^255 - 19, i.e. returns `self^(2^rep)`.
    ///
    /// With `rep == 0` the value is returned unchanged. Each step computes the full square
    /// with `sqr_256` and reduces it with `mod_25519`.
    #[inline(always)]
    fn sqr(self, rep: u32) -> Felem {
        (0..rep).fold(self, |acc, _| mod_25519(sqr_256(acc)))
    }
}
/// Squares a field element without reducing it.
///
/// Returns the full 512-bit square of `x` as eight little-endian 64-bit limbs. No modular
/// reduction happens here; the only caller in this file, `Felem::sqr`, passes the result to
/// `mod_25519`.
///
/// Uses the squaring shortcut: each cross product x_i * x_j (i < j) is computed once, the
/// running total is doubled by a one-bit left shift, and then the diagonal terms x_i^2 are
/// added.
#[inline(always)]
fn sqr_256(x: Felem) -> Felem2 {
    let x0 = u128::from(x.0[0]);
    let x1 = u128::from(x.0[1]);
    let x2 = u128::from(x.0[2]);
    let x3 = u128::from(x.0[3]);
    // Cross products with x0: x0*x1, x0*x2, x0*x3 at limbs 1, 2, 3. Each accumulator holds
    // one 64-bit limb plus a pending carry in its upper half; `>> 64` moves the carry up.
    let mut acc1 = x1.wrapping_mul(x0);
    let mut acc2 = x2.wrapping_mul(x0);
    let mut acc3 = x3.wrapping_mul(x0);
    acc2 = acc2.wrapping_add(acc1 >> 64);
    acc3 = acc3.wrapping_add(acc2 >> 64);
    let mut acc4 = acc3 >> 64;
    acc1 &= 0xffff_ffff_ffff_ffff;
    acc2 &= 0xffff_ffff_ffff_ffff;
    acc3 &= 0xffff_ffff_ffff_ffff;
    // Cross products with x1: x1*x2 at limb 3, x1*x3 at limb 4.
    let mut t = x2.wrapping_mul(x1);
    acc3 = acc3.wrapping_add(t);
    t = x3.wrapping_mul(x1);
    acc4 = acc4.wrapping_add(acc3 >> 64).wrapping_add(t);
    let mut acc5 = acc4 >> 64;
    acc3 &= 0xffff_ffff_ffff_ffff;
    acc4 &= 0xffff_ffff_ffff_ffff;
    // Cross product x2*x3 at limb 5.
    t = x3.wrapping_mul(x2);
    acc5 = acc5.wrapping_add(t);
    let mut acc6 = acc5 >> 64;
    acc5 &= 0xffff_ffff_ffff_ffff;
    // Each cross product appears twice in the square, so double the sum: shift acc1..acc6
    // left by one bit, moving the top bit of each limb into the next one. The statements run
    // from the highest limb down so that each reads its neighbour's pre-shift value; acc7
    // receives the bit shifted out of acc6.
    acc6 = acc6 << 1 | acc5 >> 63;
    acc5 = acc5 << 1 | acc4 >> 63;
    acc4 = acc4 << 1 | acc3 >> 63;
    acc3 = acc3 << 1 | acc2 >> 63;
    acc2 = acc2 << 1 | acc1 >> 63;
    acc1 <<= 1;
    let mut acc7 = acc6 >> 64;
    acc1 &= 0xffff_ffff_ffff_ffff;
    acc2 &= 0xffff_ffff_ffff_ffff;
    acc3 &= 0xffff_ffff_ffff_ffff;
    acc4 &= 0xffff_ffff_ffff_ffff;
    acc5 &= 0xffff_ffff_ffff_ffff;
    acc6 &= 0xffff_ffff_ffff_ffff;
    // Add the diagonal terms x0^2, x1^2, x2^2, x3^2 at limbs 0, 2, 4, 6, rippling carries up.
    let acc0 = x0.wrapping_mul(x0);
    acc1 = acc1.wrapping_add(acc0 >> 64);
    t = x1.wrapping_mul(x1);
    acc2 = acc2.wrapping_add(acc1 >> 64).wrapping_add(t);
    acc3 = acc3.wrapping_add(acc2 >> 64);
    t = x2.wrapping_mul(x2);
    acc4 = acc4.wrapping_add(acc3 >> 64).wrapping_add(t);
    acc5 = acc5.wrapping_add(acc4 >> 64);
    t = x3.wrapping_mul(x3);
    acc6 = acc6.wrapping_add(acc5 >> 64).wrapping_add(t);
    acc7 = acc7.wrapping_add(acc6 >> 64);
    // The `as u64` casts keep only the low 64 bits of each accumulator; the carries were
    // already propagated above.
    Felem2([
        acc0 as u64,
        acc1 as u64,
        acc2 as u64,
        acc3 as u64,
        acc4 as u64,
        acc5 as u64,
        acc6 as u64,
        acc7 as u64,
    ])
}
/// Reduces a 512-bit value (eight little-endian 64-bit limbs) to four limbs congruent to it
/// modulo p = 2^255 - 19.
///
/// The reduction folds twice:
/// 1. Since 2^256 ≡ 38 (mod p), the high limbs x4..x7 are multiplied by 38 and added to
///    x0..x3.
/// 2. Since 2^255 ≡ 19 (mod p), whatever then sits at or above bit 255 is multiplied by 19
///    and added back into the low limb.
///
/// The result is not necessarily the canonical representative (it may be ≥ p);
/// `mod_final_25519` performs the final reduction.
#[inline(always)]
fn mod_25519(x: Felem2) -> Felem {
    // 2^256 mod (2^255 - 19) = 2 * 19 = 38.
    let c38 = 38_u128;
    let mut acc0 = u128::from(x.0[0]);
    let mut acc1 = u128::from(x.0[1]);
    let mut acc2 = u128::from(x.0[2]);
    let mut acc3 = u128::from(x.0[3]);
    let mut acc4 = u128::from(x.0[4]);
    let mut acc5 = u128::from(x.0[5]);
    let mut acc6 = u128::from(x.0[6]);
    let mut acc7 = u128::from(x.0[7]);
    // Step 1: fold limbs 4..=7 into limbs 0..=3. acc0..acc2 are not masked yet; `>> 64` takes
    // their carries, and the masking happens in step 2.
    acc4 = acc4.wrapping_mul(c38);
    acc5 = acc5.wrapping_mul(c38);
    acc6 = acc6.wrapping_mul(c38);
    acc7 = acc7.wrapping_mul(c38);
    acc0 = acc0.wrapping_add(acc4);
    acc1 = acc1.wrapping_add(acc0 >> 64);
    acc1 = acc1.wrapping_add(acc5);
    acc2 = acc2.wrapping_add(acc1 >> 64);
    acc2 = acc2.wrapping_add(acc6);
    acc3 = acc3.wrapping_add(acc2 >> 64);
    acc3 = acc3.wrapping_add(acc7);
    // Step 2: split off everything at or above bit 255 (bit 63 of the top limb and higher),
    // keep the low 255 bits, and add the split-off part back in multiplied by 19.
    let mut top = (acc3 >> 63) & 0xffff_ffff_ffff_ffff;
    acc0 &= 0xffff_ffff_ffff_ffff;
    acc1 &= 0xffff_ffff_ffff_ffff;
    acc2 &= 0xffff_ffff_ffff_ffff;
    acc3 &= 0x7fff_ffff_ffff_ffff;
    top = top.wrapping_mul(19);
    acc0 = acc0.wrapping_add(top);
    acc1 = acc1.wrapping_add(acc0 >> 64);
    acc2 = acc2.wrapping_add(acc1 >> 64);
    acc3 = acc3.wrapping_add(acc2 >> 64);
    // The `as u64` casts keep the low 64 bits of each limb; the carries were propagated above.
    Felem([acc0 as u64, acc1 as u64, acc2 as u64, acc3 as u64])
}
/// Returns the canonical representative of `x` modulo p = 2^255 - 19: the unique value
/// congruent to `x` in the range 0..p, suitable for serialization.
///
/// First, bit 255 is folded back into the low limb (2^255 ≡ 19). Then p is subtracted once if
/// needed. The subtraction is due exactly when x + 19 reaches 2^255, which is detected by
/// adding 19 and looking at bit 255. It is performed arithmetically, by adding 19 or 0 and
/// clearing bit 255, rather than with a branch.
fn mod_final_25519(x: Felem) -> Felem {
    let mut acc0 = u128::from(x.0[0]);
    let mut acc1 = u128::from(x.0[1]);
    let mut acc2 = u128::from(x.0[2]);
    let mut acc3 = u128::from(x.0[3]);
    // Fold bit 255 back in: top is 0 or 1 because acc3 started as a 64-bit limb.
    let mut top = acc3 >> 63;
    acc3 &= MASK_63BITS;
    top = top.wrapping_mul(19);
    acc0 = acc0.wrapping_add(top);
    acc1 = acc1.wrapping_add(acc0 >> 64);
    acc2 = acc2.wrapping_add(acc1 >> 64);
    acc3 = acc3.wrapping_add(acc2 >> 64);
    // Mask each accumulator back to a single 64-bit limb (carries were propagated above).
    acc0 &= MASK_64BITS;
    acc1 &= MASK_64BITS;
    acc2 &= MASK_64BITS;
    acc3 &= MASK_64BITS;
    // At this point, acc{0-3} is in the range 0 to 2^255 + 18, inclusive. It is not
    // necessarily below 2^255 - 19 yet, so one more conditional reduction is done.
    // Propagate only the carry of (acc + 19) to learn whether it reaches bit 255.
    top = acc0.wrapping_add(19) >> 64;
    top = acc1.wrapping_add(top) >> 64;
    top = acc2.wrapping_add(top) >> 64;
    top = acc3.wrapping_add(top) >> 63;
    top = top.wrapping_mul(19);
    // top is 19 if acc{0-3} is between 2^255 - 19 and 2^255 + 18, inclusive. Otherwise, it's
    // zero. Adding 19 and dropping bit 255 is the same as subtracting p.
    acc0 = acc0.wrapping_add(top);
    acc1 = acc1.wrapping_add(acc0 >> 64);
    acc2 = acc2.wrapping_add(acc1 >> 64);
    acc3 = acc3.wrapping_add(acc2 >> 64);
    acc3 &= MASK_63BITS;
    // Now acc{0-3} is between 0 and 2^255 - 20, inclusive.
    Felem([acc0 as u64, acc1 as u64, acc2 as u64, acc3 as u64])
}
/// Returns the multiplicative inverse of `x` modulo p = 2^255 - 19.
///
/// Uses Fermat's little theorem, x^-1 = x^(p-2) = x^(2^255 - 21). The power is evaluated with
/// a fixed addition chain of 254 squarings and 11 multiplications, so the sequence of field
/// operations does not depend on `x`. If x ≡ 0 (mod p), the result is congruent to 0.
fn mod_inv_25519(x: Felem) -> Felem {
    // Naming: `e_N` holds x^N and `run_K` holds x^(2^K - 1), i.e. an exponent of K one-bits.
    let e_2 = x.sqr(1);
    let e_9 = e_2.sqr(2) * x; // x^8 * x
    let e_11 = e_9 * e_2; // x^9 * x^2
    let run_5 = e_11.sqr(1) * e_9; // x^22 * x^9 = x^31
    let run_10 = run_5.sqr(5) * run_5;
    let run_20 = run_10.sqr(10) * run_10;
    let run_40 = run_20.sqr(20) * run_20;
    let run_50 = run_40.sqr(10) * run_10;
    let run_100 = run_50.sqr(50) * run_50;
    let run_200 = run_100.sqr(100) * run_100;
    let run_250 = run_200.sqr(50) * run_50;
    // (2^250 - 1) * 2^5 + 11 = 2^255 - 21 = p - 2.
    run_250.sqr(5) * e_11
}
/// Conditionally exchanges two field elements.
///
/// Returns `(b, a)` when `swap == 1` and `(a, b)` when `swap == 0`. The choice is made with an
/// XOR mask (`swap.wrapping_neg()`, all ones or all zeros) instead of a branch on `swap`.
/// Values of `swap` other than 0 or 1 do not produce a clean swap.
#[inline(always)]
fn constant_time_swap(a: Felem, b: Felem, swap: u64) -> (Felem, Felem) {
    let mask = swap.wrapping_neg();
    let mut first = a.0;
    let mut second = b.0;
    for (lhs, rhs) in first.iter_mut().zip(second.iter_mut()) {
        // `flip` is the XOR difference of the limbs when swapping, and zero otherwise.
        let flip = mask & (*lhs ^ *rhs);
        *lhs ^= flip;
        *rhs ^= flip;
    }
    (Felem(first), Felem(second))
}
/// Computes the X25519 function (RFC 7748, section 5): the u-coordinate of `k * P`.
///
/// Here `k` is `secret_key` after clamping, and `P` is the Curve25519 point whose u-coordinate
/// is `peer_key`. With the base point (u = 9) as `peer_key` this yields the public key for
/// `secret_key`; with another party's public key it yields the Diffie-Hellman shared secret.
///
/// Both inputs are 32-byte little-endian encodings. As RFC 7748 requires, the scalar is
/// clamped (bits 0-2 and 255 cleared, bit 254 set), and the most significant bit of the peer's
/// u-coordinate is ignored. The returned bytes are the fully reduced little-endian encoding of
/// the resulting u-coordinate.
///
/// # Panics
///
/// Panics if `peer_key` or `secret_key` is not exactly 32 bytes long.
#[inline]
pub fn x25519_shared_key(peer_key: &[u8], secret_key: &[u8]) -> Key {
    if peer_key.len() != 32 || secret_key.len() != 32 {
        panic!("Illegal values for x25519");
    }

    // Clamp the scalar (RFC 7748 `decodeScalar25519`).
    let mut scalar = [0_u8; 32];
    scalar.copy_from_slice(secret_key);
    scalar[0] &= 248;
    scalar[31] &= 127;
    scalar[31] |= 64;

    // Decode u (RFC 7748 `decodeUCoordinate`) into four little-endian 64-bit limbs. The RFC
    // requires masking the most significant bit of the final byte. Without the mask, an input
    // with bit 255 set would be treated as u + 19 (2^255 = 19 mod p) instead of u.
    let mut u_limbs = [0_u64; 4];
    for (limb, bytes) in u_limbs.iter_mut().zip(peer_key.chunks_exact(8)) {
        *limb = u64::from_le_bytes(make_array(bytes));
    }
    u_limbs[3] &= 0x7fff_ffff_ffff_ffff;
    let u = Felem(u_limbs);

    // Montgomery ladder over bits 254..=0 of the clamped scalar (bit 255 is always 0).
    // (x_2 : z_2) and (x_3 : z_3) are projective u-coordinates. a24 = 121666 = (A + 2) / 4
    // for Curve25519's A = 486662, matching the `BB + a24 * E` form used below.
    let x_1 = u;
    let mut x_2 = Felem([1, 0, 0, 0]);
    let mut z_2 = Felem([0, 0, 0, 0]);
    let mut x_3 = u;
    let mut z_3 = Felem([1, 0, 0, 0]);
    let a24 = Felem([121_666, 0, 0, 0]);
    // The swap for each bit is merged with the pending swap from the previous bit
    // (swap ^= bit), so only one conditional swap runs per iteration.
    let mut swap = 0;
    for pos in (0..=254).rev() {
        let bit_val = u64::from((scalar[pos / 8] >> (pos & 7)) & 1);
        swap ^= bit_val;
        let (mut x2, mut x3) = constant_time_swap(x_2, x_3, swap);
        let (mut z2, mut z3) = constant_time_swap(z_2, z_3, swap);
        swap = bit_val;

        // Combined differential addition and doubling; RFC 7748 names are given on the right.
        let mut tmp0 = x3 - z3; // D
        let mut tmp1 = x2 - z2; // B
        x2 = x2 + z2; // A
        z2 = x3 + z3; // C
        z3 = x2 * tmp0; // DA
        z2 = z2 * tmp1; // CB
        tmp0 = tmp1.sqr(1); // BB
        tmp1 = x2.sqr(1); // AA
        x3 = z3 + z2; // DA + CB
        z2 = z3 - z2; // DA - CB
        x_2 = tmp1 * tmp0; // AA * BB
        tmp1 = tmp1 - tmp0; // E = AA - BB
        z2 = z2.sqr(1); // (DA - CB)^2
        z3 = a24 * tmp1; // a24 * E
        x_3 = x3.sqr(1); // (DA + CB)^2
        tmp0 = tmp0 + z3; // BB + a24 * E
        z_3 = x_1 * z2; // x_1 * (DA - CB)^2
        z_2 = tmp1 * tmp0; // E * (BB + a24 * E)
    }
    let (x2, _) = constant_time_swap(x_2, x_3, swap);
    let (z2, _) = constant_time_swap(z_2, z_3, swap);

    // Projective to affine: u = x2 / z2, fully reduced before encoding.
    let key = mod_final_25519(x2 * mod_inv_25519(z2));

    let mut shared_key = [0_u8; 32];
    for (bytes, limb) in shared_key.chunks_exact_mut(8).zip(key.0.iter()) {
        bytes.copy_from_slice(&limb.to_le_bytes());
    }
    shared_key
}