//! Implementation of encryption for backups.
//!
//! The API for this module is modeled after
//! [AEAD](https://en.wikipedia.org/wiki/AEAD): authenticated
//! encryption with associated data. The AD is non-encrypted binary
//! data (usually small) that is stored with the ciphertext to allow
//! choosing the right ciphertext to use. The AD might, for example,
//! be an ID of the plain text data. The decryption requires both the
//! key and the exact AD used with encryption. This authenticates the
//! AD: if the decryption succeeds, the AD was correct. The security
//! comes from the key, which must be stored securely.
//!
//! The API is designed to allow new encryption methods to be added in
//! the future, but currently only provides a wrapper around the
//! [aes-gcm-siv](https://crates.io/crates/aes-gcm-siv) crate.

use aes_gcm_siv::{
    aead::{generic_array::GenericArray, rand_core::RngCore, Aead, KeyInit, OsRng, Payload},
    Aes256GcmSiv,
};
use serde::{Deserialize, Serialize};

/// What kind of cipher engine to use.
///
/// This is a plain, field-less selector, so it is cheap to copy,
/// compare, and print in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EngineKind {
    /// An AEAD cipher engine using the
    /// [aes-gcm-siv](https://crates.io/crates/aes-gcm-siv) crate
    Aead,
}

/// A cipher engine.
///
/// Wraps the concrete engine implementation that was selected via
/// [`EngineKind`] when the engine was created with [`Engine::new`].
pub struct Engine {
    // The concrete implementation that encrypt/decrypt calls are
    // forwarded to.
    engine: ActualEngine,
}

impl Engine {
    /// Construct an [`Engine`] of the requested `kind` that uses
    /// `key` for every subsequent encryption and decryption.
    pub fn new(kind: EngineKind, key: Key) -> Self {
        let engine = match kind {
            EngineKind::Aead => ActualEngine::Aead(AeadEngine::new(key)),
        };
        Self { engine }
    }

    /// Encrypt `data`, binding it to the associated data `ad`, with
    /// the key given at creation time.
    ///
    /// # Errors
    ///
    /// Returns whatever [`CipherError`] the underlying engine reports.
    pub fn encrypt(&self, data: &[u8], ad: &[u8]) -> Result<Vec<u8>, CipherError> {
        // There is only one engine variant, so the pattern is
        // irrefutable; adding a variant turns this into a compile error.
        let ActualEngine::Aead(aead) = &self.engine;
        aead.encrypt(data, ad)
    }

    /// Decrypt `ciphertext` with the key given at creation time,
    /// using the associated data `ad`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`CipherError`] the underlying engine reports.
    pub fn decrypt(&self, ciphertext: &[u8], ad: &[u8]) -> Result<Vec<u8>, CipherError> {
        let ActualEngine::Aead(aead) = &self.engine;
        aead.decrypt(ciphertext, ad)
    }
}

/// The concrete engine implementations an [`Engine`] can wrap; one
/// variant per [`EngineKind`].
enum ActualEngine {
    /// Engine built for [`EngineKind::Aead`].
    Aead(AeadEngine),
}

/// Cipher engine backed by `Aes256GcmSiv` from the `aes-gcm-siv` crate.
struct AeadEngine {
    // Key handed to the cipher for every encrypt/decrypt call.
    key: Key,
}

impl AeadEngine {
    fn new(key: Key) -> Self {
        Self { key }
    }

    fn encrypt(&self, data: &[u8], ad: &[u8]) -> Result<Vec<u8>, CipherError> {
        let key = GenericArray::from_slice(self.key.as_slice());
        let nonce = Nonce::default();
        let aes_gcm_siv_nonce = GenericArray::from_slice(&nonce.bytes);
        let cipher = Aes256GcmSiv::new(key);
        let payload = Payload { msg: data, aad: ad };
        let ciphertext = cipher
            .encrypt(aes_gcm_siv_nonce, payload)
            .map_err(|_| CipherError::Encrypt)?;
        let ciphertext = AeadCiphertext::new(ciphertext, nonce);
        ciphertext.to_vec()
    }

    fn decrypt(&self, ciphertext: &[u8], ad: &[u8]) -> Result<Vec<u8>, CipherError> {
        let ciphertext = AeadCiphertext::try_from(ciphertext)?;
        let key = GenericArray::from_slice(self.key.as_slice());
        let nonce = GenericArray::from_slice(ciphertext.nonce.as_slice());
        let cipher = Aes256GcmSiv::new(key);
        let payload = Payload {
            msg: ciphertext.as_bytes(),
            aad: ad,
        };
        let plain = cipher
            .decrypt(nonce, payload)
            .map_err(|_| CipherError::Decrypt)?;
        Ok(plain)
    }
}

/// Encrypted AEAD data.
///
/// Bundles the raw cipher output with the nonce that was used to
/// produce it, so that both can be serialized and stored together.
#[derive(Debug, Serialize, Deserialize)]
pub struct AeadCiphertext {
    // Raw output of the cipher.
    ciphertext: Vec<u8>,
    // Nonce the ciphertext was produced with; needed again to decrypt.
    nonce: Nonce,
}

impl AeadCiphertext {
    /// Pair the raw cipher output `blob` with the `nonce` it was
    /// produced with.
    fn new(blob: Vec<u8>, nonce: Nonce) -> Self {
        let ciphertext = blob;
        Self { ciphertext, nonce }
    }

    /// The raw cipher output, without the nonce.
    fn as_bytes(&self) -> &[u8] {
        self.ciphertext.as_slice()
    }

    /// Serialize the ciphertext and its nonce with `postcard`.
    ///
    /// # Errors
    ///
    /// Returns [`CipherError::AeadSerialize`] if serialization fails.
    fn to_vec(&self) -> Result<Vec<u8>, CipherError> {
        match postcard::to_allocvec(self) {
            Ok(serialized) => Ok(serialized),
            Err(err) => Err(CipherError::AeadSerialize(err)),
        }
    }
}

impl TryFrom<&[u8]> for AeadCiphertext {
    type Error = CipherError;

    /// De-serialize an [`AeadCiphertext`] from its `postcard` encoding.
    ///
    /// # Errors
    ///
    /// Returns [`CipherError::AeadParse`] if `blob` can't be parsed.
    fn try_from(blob: &[u8]) -> Result<Self, Self::Error> {
        match postcard::from_bytes(blob) {
            Ok(parsed) => Ok(parsed),
            Err(err) => Err(CipherError::AeadParse(err)),
        }
    }
}

/// A nonce for AEAD.
///
/// Each encryption creates a new random nonce. For the AEAD we use,
/// it must always be exactly 96 bits (12 bytes). The nonce is only
/// used internally in this module.
#[derive(Debug, Serialize, Deserialize)]
struct Nonce {
    // The nonce bytes. Stored as a `Vec`, so the length is not
    // enforced by the type.
    bytes: Vec<u8>,
}

impl Nonce {
    /// Borrow the nonce bytes.
    fn as_slice(&self) -> &[u8] {
        self.bytes.as_slice()
    }
}

impl Default for Nonce {
    /// Produce a fresh 12-byte nonce filled with random bytes drawn
    /// from `OsRng`.
    fn default() -> Self {
        const NONCE_BYTES: usize = 12;
        let mut bytes = vec![0u8; NONCE_BYTES];
        OsRng.fill_bytes(bytes.as_mut_slice());
        Self { bytes }
    }
}

/// A symmetric encryption key.
///
/// The [`Debug`](std::fmt::Debug) representation deliberately does
/// not include the key bytes, so that formatting a `Key` (or a value
/// containing one) for logs or error messages can't leak the secret.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct Key {
    // The raw key bytes. This is the secret; keep it out of any
    // diagnostic output.
    key: Vec<u8>,
}

// Hand-written instead of derived: the derived impl would print the
// raw key bytes.
impl std::fmt::Debug for Key {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Key").field("key", &"<redacted>").finish()
    }
}

impl Key {
    /// Borrow the raw bytes of the key.
    pub fn as_slice(&self) -> &[u8] {
        self.key.as_slice()
    }
}

impl Default for Key {
    /// Generate a fresh random key of the size `Aes256GcmSiv`
    /// expects, using `OsRng` as the source of randomness.
    fn default() -> Self {
        let generated = Aes256GcmSiv::generate_key(&mut OsRng);
        let key = generated.to_vec();
        Self { key }
    }
}

impl From<&String> for Key {
    /// Build a key from the UTF-8 bytes of `key`, via the byte-slice
    /// conversion.
    fn from(key: &String) -> Self {
        let bytes: &[u8] = key.as_bytes();
        bytes.into()
    }
}

impl From<&str> for Key {
    /// Build a key from the UTF-8 bytes of `key`, via the byte-slice
    /// conversion.
    fn from(key: &str) -> Self {
        let bytes: &[u8] = key.as_bytes();
        bytes.into()
    }
}

impl From<&[u8]> for Key {
    /// Build a key from a copy of the given bytes, via the `Vec<u8>`
    /// conversion.
    fn from(key: &[u8]) -> Self {
        let owned: Vec<u8> = key.to_owned();
        owned.into()
    }
}

impl From<Vec<u8>> for Key {
    /// Build a key from arbitrary bytes.
    ///
    /// The `aes-gcm-siv` crate wants keys of a specific length. FOR
    /// TESTING, we want to avoid that restriction, so the input is
    /// truncated, or padded with zero bytes, to exactly that length.
    fn from(mut key: Vec<u8>) -> Self {
        // AES-256 keys are exactly 256 bits. Using the constant avoids
        // generating (and discarding) a random key from the OS RNG
        // just to learn its length.
        const KEY_BYTES: usize = 32;

        // `resize` both truncates longer input and zero-pads shorter
        // input.
        key.resize(KEY_BYTES, b'\0');

        Self { key }
    }
}

/// All possible errors from the `cipher` module.
///
/// The encryption and decryption variants carry no detail: the
/// underlying cipher error is discarded where they are created.
#[derive(Debug, thiserror::Error)]
pub enum CipherError {
    /// Encryption failed.
    #[error("encryption failed")]
    Encrypt,

    /// Decryption failed.
    #[error("decryption failed")]
    Decrypt,

    /// Can't serialize an AEAD ciphertext.
    #[error("failed to serialize AEAD ciphertext")]
    AeadSerialize(#[source] postcard::Error),

    /// Serialized AEAD ciphertext doesn't start with the expected cookie.
    // NOTE(review): nothing in this module constructs this variant;
    // presumably it is used elsewhere or kept for compatibility.
    #[error("encrypted data doesn't start with magic cookie")]
    AeadNoCookie,

    /// Can't de-serialize an AEAD ciphertext.
    #[error("failed to parse AEAD ciphertext")]
    AeadParse(#[source] postcard::Error),
}

#[cfg(test)]
mod test {
    use super::*;

    /// Assert that `key` consists of `prefix` followed by zero
    /// padding up to the length of a generated key.
    fn assert_zero_padded(key: &Key, prefix: &[u8]) {
        let actual = key.as_slice();
        assert_eq!(actual.len(), Key::default().as_slice().len());
        assert!(actual.starts_with(prefix));
        assert!(actual[prefix.len()..].iter().all(|b| *b == 0));
    }

    /// Encrypt a fixed plaintext with a fresh engine; returns the
    /// engine, the plaintext, and the ciphertext bound to AD `label`.
    fn encrypted_sample() -> (Engine, Vec<u8>, Vec<u8>) {
        let engine = Engine::new(EngineKind::Aead, Key::default());
        let plaintext = b"hello, world".to_vec();
        let ciphertext = engine.encrypt(&plaintext, b"label").unwrap();
        (engine, plaintext, ciphertext)
    }

    #[test]
    fn key() {
        assert_zero_padded(&Key::from("xyzzy"), b"xyzzy");
    }

    #[test]
    fn key_from_string() {
        assert_zero_padded(&Key::from(&String::from("xyzzy")), b"xyzzy");
    }

    #[test]
    fn key_is_truncated() {
        let wanted = Key::default().as_slice().len();
        let key = Key::from(vec![b'x'; wanted + 10]);
        assert_eq!(key.as_slice(), vec![b'x'; wanted].as_slice());
    }

    #[test]
    fn aead_round_trip() {
        let (engine, plaintext, ciphertext) = encrypted_sample();
        let decrypted = engine.decrypt(&ciphertext, b"label").unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn aead_rejects_wrong_ad() {
        // The AD is authenticated: decryption must fail if it differs
        // from the one used for encryption.
        let (engine, _, ciphertext) = encrypted_sample();
        assert!(matches!(
            engine.decrypt(&ciphertext, b"other"),
            Err(CipherError::Decrypt)
        ));
    }

    #[test]
    fn aead_rejects_wrong_key() {
        let (_, _, ciphertext) = encrypted_sample();
        let other = Engine::new(EngineKind::Aead, Key::default());
        assert!(matches!(
            other.decrypt(&ciphertext, b"label"),
            Err(CipherError::Decrypt)
        ));
    }

    #[test]
    fn aead_rejects_unparseable_input() {
        // Empty input can't be de-serialized into a ciphertext.
        let engine = Engine::new(EngineKind::Aead, Key::default());
        assert!(matches!(
            engine.decrypt(&[], b"label"),
            Err(CipherError::AeadParse(_))
        ));
    }
}
