Private/RSAEncryptionProvider.cs

/*
 * Identity Security Cloud V3 API
 *
 * Use these APIs to interact with the Identity Security Cloud platform to achieve repeatable, automated processes with greater scalability. We encourage you to join the SailPoint Developer Community forum at https://developer.sailpoint.com/discuss to connect with other developers using our APIs.
 *
 * The version of the OpenAPI document: 3.0.0
 * Generated by: https://github.com/openapitools/openapi-generator.git
 */
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Runtime.InteropServices;
using System.Security;
using System.Security.Cryptography;
using System.Text;

namespace RSAEncryption
{
    /// <summary>
    /// An RSA encryption provider. Loads PKCS#1 RSA private keys ("RSA PRIVATE KEY" PEM blocks),
    /// either unencrypted or encrypted with the legacy OpenSSL DES-EDE3-CBC scheme.
    /// </summary>
    public class RSAEncryptionProvider
    {
        private const String PemPublicHeader = "-----BEGIN PUBLIC KEY-----";
        private const String PemPublicFooter = "-----END PUBLIC KEY-----";
        private const String PemPrivateHeader = "-----BEGIN RSA PRIVATE KEY-----";
        private const String PemPrivateFooter = "-----END RSA PRIVATE KEY-----";

        /// <summary>
        /// Get the RSA provider from the PEM file.
        /// </summary>
        /// <param name="pemfile">Path of the PEM file.</param>
        /// <param name="keyPassPhrase">Key pass phrase; only needed when the private key is encrypted.</param>
        /// <returns>
        /// An instance of RSACryptoServiceProvider, or null when the file holds a public key,
        /// is not a supported private key, or cannot be decrypted/decoded.
        /// </returns>
        /// <exception cref="FileNotFoundException">The PEM file does not exist.</exception>
        public static RSACryptoServiceProvider GetRSAProviderFromPemFile(String pemfile, SecureString keyPassPhrase = null)
        {
            if (!File.Exists(pemfile))
            {
                throw new FileNotFoundException("private key file does not exist.", pemfile);
            }
            string pemstr = File.ReadAllText(pemfile).Trim();

            // Public key files are recognized but not supported by this provider.
            if (pemstr.StartsWith(PemPublicHeader, StringComparison.Ordinal)
                && pemstr.EndsWith(PemPublicFooter, StringComparison.Ordinal))
            {
                return null;
            }

            byte[] pemkey = ConvertPrivateKeyToBytes(pemstr, keyPassPhrase);
            if (pemkey == null)
            {
                return null;
            }

            try
            {
                return DecodeRSAPrivateKey(pemkey);
            }
            finally
            {
                // The decoder copies what it needs, so the raw key bytes can be wiped.
                Array.Clear(pemkey, 0, pemkey.Length);
            }
        }

        /// <summary>
        /// Convert the private key to bytes.
        /// </summary>
        /// <param name="instr">Private key in PEM format.</param>
        /// <param name="keyPassPhrase">Key pass phrase; only needed when the private key is encrypted.</param>
        /// <returns>The DER-encoded private key bytes, or null when the input is not a usable RSA private key.</returns>
        private static byte[] ConvertPrivateKeyToBytes(String instr, SecureString keyPassPhrase = null)
        {
            if (instr == null)
            {
                return null;
            }

            String pemstr = instr.Trim();
            if (!pemstr.StartsWith(PemPrivateHeader, StringComparison.Ordinal)
                || !pemstr.EndsWith(PemPrivateFooter, StringComparison.Ordinal))
            {
                return null;
            }

            StringBuilder sb = new StringBuilder(pemstr);
            sb.Replace(PemPrivateHeader, "");
            sb.Replace(PemPrivateFooter, "");
            String pvkstr = sb.ToString().Trim();

            try
            {
                // If there are no PEM encryption info lines, this is an unencrypted PEM private key.
                return Convert.FromBase64String(pvkstr);
            }
            catch (FormatException)
            {
                // Not plain base64: fall through and treat the body as an encrypted key.
            }

            return DecryptPemBody(pvkstr, keyPassPhrase);
        }

        /// <summary>
        /// Decrypt the body of an encrypted PEM private key (Proc-Type / DEK-Info header lines,
        /// a blank line, then the base64 encoded cipher text).
        /// </summary>
        /// <param name="pvkstr">PEM body without the BEGIN/END lines.</param>
        /// <param name="keyPassPhrase">Key pass phrase.</param>
        /// <returns>The decrypted private key bytes, or null when the body is malformed or cannot be decrypted.</returns>
        private static byte[] DecryptPemBody(String pvkstr, SecureString keyPassPhrase)
        {
            using (StringReader reader = new StringReader(pvkstr))
            {
                //-------- read PEM encryption info. lines and extract salt -----
                String procTypeLine = reader.ReadLine();
                if (procTypeLine == null || !procTypeLine.StartsWith("Proc-Type: 4,ENCRYPTED", StringComparison.Ordinal))
                {
                    return null;
                }

                String saltline = reader.ReadLine();
                if (saltline == null || !saltline.StartsWith("DEK-Info: DES-EDE3-CBC,", StringComparison.Ordinal))
                {
                    return null;
                }

                byte[] salt = ParseHexString(saltline.Substring(saltline.IndexOf(',') + 1).Trim());
                if (salt == null)
                {
                    return null;
                }

                // A blank line must separate the header lines from the base64 data.
                if (reader.ReadLine() != "")
                {
                    return null;
                }

                //------ remaining b64 data is encrypted RSA key ----
                byte[] binkey;
                try
                {
                    binkey = Convert.FromBase64String(reader.ReadToEnd());
                }
                catch (FormatException)
                {
                    // data is not in base64 format
                    return null;
                }

                // count=1 (for OpenSSL implementation); 2 iterations to get at least 24 bytes
                byte[] deskey = GetEncryptedKey(salt, keyPassPhrase, 1, 2);
                if (deskey == null)
                {
                    return null;
                }

                try
                {
                    // OpenSSL uses the salt value in the PEM header also as the 3DES IV.
                    return DecryptKey(binkey, deskey, salt);
                }
                finally
                {
                    Array.Clear(deskey, 0, deskey.Length);
                }
            }
        }

        /// <summary>
        /// Parse a string of hexadecimal digit pairs into bytes.
        /// </summary>
        /// <param name="hex">Hexadecimal text.</param>
        /// <returns>The parsed bytes, or null when the text has an odd length or a non-hex character.</returns>
        private static byte[] ParseHexString(String hex)
        {
            if (hex.Length % 2 != 0)
            {
                return null;
            }

            byte[] bytes = new byte[hex.Length / 2];
            for (int i = 0; i < bytes.Length; i++)
            {
                int high = HexDigitValue(hex[i * 2]);
                int low = HexDigitValue(hex[i * 2 + 1]);
                if (high < 0 || low < 0)
                {
                    return null;
                }
                bytes[i] = (byte)((high << 4) | low);
            }
            return bytes;
        }

        /// <summary>
        /// Get the numeric value of a hexadecimal digit.
        /// </summary>
        /// <param name="c">Character to convert.</param>
        /// <returns>The value 0-15, or -1 when the character is not a hexadecimal digit.</returns>
        private static int HexDigitValue(char c)
        {
            if (c >= '0' && c <= '9')
            {
                return c - '0';
            }
            if (c >= 'a' && c <= 'f')
            {
                return c - 'a' + 10;
            }
            if (c >= 'A' && c <= 'F')
            {
                return c - 'A' + 10;
            }
            return -1;
        }

        /// <summary>
        /// Decode the RSA private key.
        /// </summary>
        /// <param name="privkey">ASN.1 (DER) encoded PKCS#1 RSA private key.</param>
        /// <returns>An instance of RSACryptoServiceProvider, or null when the key cannot be decoded or imported.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="privkey"/> is null.</exception>
        public static RSACryptoServiceProvider DecodeRSAPrivateKey(byte[] privkey)
        {
            if (privkey == null)
            {
                throw new ArgumentNullException("privkey");
            }

            // --------- Set up stream to decode the asn.1 encoded RSA private key ------
            using (MemoryStream mem = new MemoryStream(privkey))
            using (BinaryReader binr = new BinaryReader(mem)) // wrap MemoryStream with BinaryReader for easy reading
            {
                RSACryptoServiceProvider rsa = null;
                try
                {
                    ushort twobytes = binr.ReadUInt16();
                    if (twobytes == 0x8130) // data read as little endian order (actual data order for Sequence is 30 81)
                    {
                        binr.ReadByte(); // advance 1 byte
                    }
                    else if (twobytes == 0x8230)
                    {
                        binr.ReadInt16(); // advance 2 bytes
                    }
                    else
                    {
                        return null;
                    }

                    twobytes = binr.ReadUInt16();
                    if (twobytes != 0x0102) // version number
                    {
                        return null;
                    }
                    if (binr.ReadByte() != 0x00)
                    {
                        return null;
                    }

                    //------ all private key components are Integer sequences ----
                    byte[] bytesModulus = ReadIntegerBytes(binr);
                    byte[] bytesE = ReadIntegerBytes(binr);
                    byte[] bytesD = ReadIntegerBytes(binr);
                    byte[] bytesP = ReadIntegerBytes(binr);
                    byte[] bytesQ = ReadIntegerBytes(binr);
                    byte[] bytesDp = ReadIntegerBytes(binr);
                    byte[] bytesDq = ReadIntegerBytes(binr);
                    byte[] bytesIq = ReadIntegerBytes(binr);

                    // Leading zero bytes were stripped while reading. ImportParameters expects D to be as
                    // long as the modulus and the CRT values to be half that length (rounded up), so values
                    // that happened to start with a zero byte must be padded back to the expected size.
                    int fullLength = bytesModulus.Length;
                    int halfLength = (fullLength + 1) / 2;

                    // ------- create RSACryptoServiceProvider instance and initialize with the private key -----
                    RSAParameters rsaParams = new RSAParameters();
                    rsaParams.Modulus = bytesModulus;
                    rsaParams.Exponent = bytesE;
                    rsaParams.D = PadLeft(bytesD, fullLength);
                    rsaParams.P = PadLeft(bytesP, halfLength);
                    rsaParams.Q = PadLeft(bytesQ, halfLength);
                    rsaParams.DP = PadLeft(bytesDp, halfLength);
                    rsaParams.DQ = PadLeft(bytesDq, halfLength);
                    rsaParams.InverseQ = PadLeft(bytesIq, halfLength);

                    rsa = new RSACryptoServiceProvider();
                    rsa.ImportParameters(rsaParams);

                    RSACryptoServiceProvider imported = rsa;
                    rsa = null; // ownership passes to the caller
                    return imported;
                }
                catch (Exception)
                {
                    // Any decoding or import failure means the key is unusable.
                    return null;
                }
                finally
                {
                    // Only non-null when the import failed after the provider was created.
                    if (rsa != null)
                    {
                        rsa.Dispose();
                    }
                }
            }
        }

        /// <summary>
        /// Read the content of the next ASN.1 INTEGER, without its leading zero bytes.
        /// </summary>
        /// <param name="binr">Reader positioned at the INTEGER tag.</param>
        /// <returns>The integer bytes in big-endian order.</returns>
        /// <exception cref="InvalidDataException">The INTEGER is missing, empty or truncated.</exception>
        private static byte[] ReadIntegerBytes(BinaryReader binr)
        {
            int size = GetIntegerSize(binr);
            byte[] value = binr.ReadBytes(size);
            // ReadBytes returns fewer bytes when the stream ends early.
            if (size == 0 || value.Length != size)
            {
                throw new InvalidDataException("Missing, empty or truncated INTEGER in RSA private key.");
            }
            return value;
        }

        /// <summary>
        /// Left-pad a big-endian value with zero bytes up to the given length.
        /// </summary>
        /// <param name="value">Value to pad.</param>
        /// <param name="length">Required length.</param>
        /// <returns>The padded value, or the original array when it is already long enough.</returns>
        private static byte[] PadLeft(byte[] value, int length)
        {
            if (value.Length >= length)
            {
                return value;
            }

            byte[] padded = new byte[length];
            Buffer.BlockCopy(value, 0, padded, length - value.Length, value.Length);
            return padded;
        }

        /// <summary>
        /// Read an ASN.1 INTEGER header and skip the leading zero bytes of its content.
        /// </summary>
        /// <param name="binr">Reader positioned at the INTEGER tag.</param>
        /// <returns>
        /// The number of content bytes left to read after the leading zeros,
        /// or 0 when the next element is not an INTEGER.
        /// </returns>
        private static int GetIntegerSize(BinaryReader binr)
        {
            if (binr.ReadByte() != 0x02) // expect integer
            {
                return 0;
            }

            int count;
            byte bt = binr.ReadByte();
            if (bt == 0x81)
            {
                count = binr.ReadByte(); // data size in next byte
            }
            else if (bt == 0x82)
            {
                byte highbyte = binr.ReadByte(); // data size in next 2 bytes
                byte lowbyte = binr.ReadByte();
                count = (highbyte << 8) | lowbyte;
            }
            else
            {
                count = bt; // we already have the data size
            }

            // Remove high order zeros in data, but never read beyond this integer's own content.
            while (count > 0 && binr.ReadByte() == 0x00)
            {
                count -= 1;
            }
            if (count > 0)
            {
                // last ReadByte wasn't a removed zero, so back up a byte
                binr.BaseStream.Seek(-1, SeekOrigin.Current);
            }
            return count;
        }

        /// <summary>
        /// Derive the 3DES key from the pass phrase and salt by hashing them with MD5
        /// and concatenating the successive digests.
        /// </summary>
        /// <param name="salt">Salt bytes taken from the PEM DEK-Info header.</param>
        /// <param name="secpswd">Password.</param>
        /// <param name="count">Number of times each block is hashed.</param>
        /// <param name="miter">Number of 16-byte digests to concatenate; at least 2 are needed for the 24-byte key.</param>
        /// <returns>The 24-byte 3DES key, or null when the salt or the password is missing.</returns>
        private static byte[] GetEncryptedKey(byte[] salt, SecureString secpswd, int count, int miter)
        {
            // An encrypted key without a pass phrase cannot be opened; report failure instead of crashing.
            if (salt == null || secpswd == null)
            {
                return null;
            }

            const int HashLength = 16; // MD5 bytes

            byte[] psbytes = new byte[secpswd.Length];
            IntPtr unmanagedPswd = IntPtr.Zero;
            try
            {
                unmanagedPswd = Marshal.SecureStringToGlobalAllocAnsi(secpswd);
                Marshal.Copy(unmanagedPswd, psbytes, 0, psbytes.Length);
            }
            finally
            {
                // Always wipe and free the unmanaged copy of the password.
                if (unmanagedPswd != IntPtr.Zero)
                {
                    Marshal.ZeroFreeGlobalAllocAnsi(unmanagedPswd);
                }
            }

            byte[] data00 = new byte[psbytes.Length + salt.Length]; // password followed by salt
            byte[] keymaterial = new byte[HashLength * miter]; // to store concatenated hashed results
            byte[] hashtarget = new byte[HashLength + data00.Length]; // previous digest followed by data00
            byte[] result = null;
            try
            {
                // --- concatenate salt and pswd bytes into fixed data array ---
                Array.Copy(psbytes, data00, psbytes.Length);
                Array.Copy(salt, 0, data00, psbytes.Length, salt.Length);

                // ---- do multi-hashing and concatenate results D1, D2 ... into keymaterial bytes ----
                using (MD5 md5 = MD5.Create())
                {
                    for (int j = 0; j < miter; j++)
                    {
                        if (j == 0)
                        {
                            result = data00; // first block hashes password + salt only
                        }
                        else
                        {
                            // later blocks hash the previous digest followed by password + salt
                            Array.Copy(result, hashtarget, result.Length);
                            Array.Copy(data00, 0, hashtarget, result.Length, data00.Length);
                            result = hashtarget;
                        }

                        // ---- Now hash consecutively for count times ------
                        for (int i = 0; i < count; i++)
                        {
                            result = md5.ComputeHash(result);
                        }
                        Array.Copy(result, 0, keymaterial, j * HashLength, result.Length); // concatenate to keymaterial
                    }
                }

                byte[] deskey = new byte[24];
                Array.Copy(keymaterial, deskey, deskey.Length);
                return deskey;
            }
            finally
            {
                // Wipe every buffer that held password-derived material.
                Array.Clear(psbytes, 0, psbytes.Length);
                Array.Clear(data00, 0, data00.Length);
                Array.Clear(hashtarget, 0, hashtarget.Length);
                Array.Clear(keymaterial, 0, keymaterial.Length);
                if (result != null)
                {
                    Array.Clear(result, 0, result.Length);
                }
            }
        }

        /// <summary>
        /// Decrypt the key.
        /// </summary>
        /// <param name="cipherData">Cipher data.</param>
        /// <param name="desKey">Key to decrypt.</param>
        /// <param name="IV">Initialization vector.</param>
        /// <returns>Decrypted key, or null when the key, IV or cipher data is rejected by the algorithm.</returns>
        private static byte[] DecryptKey(byte[] cipherData, byte[] desKey, byte[] IV)
        {
            try
            {
                using (TripleDES alg = TripleDES.Create())
                {
                    // Assigned inside the try so an invalid key or IV length also yields null.
                    alg.Key = desKey;
                    alg.IV = IV;

                    using (MemoryStream memst = new MemoryStream())
                    {
                        using (ICryptoTransform decryptor = alg.CreateDecryptor())
                        using (CryptoStream cs = new CryptoStream(memst, decryptor, CryptoStreamMode.Write))
                        {
                            cs.Write(cipherData, 0, cipherData.Length);
                        } // disposing the CryptoStream flushes the final block

                        // ToArray still works after the stream has been closed.
                        return memst.ToArray();
                    }
                }
            }
            catch (Exception)
            {
                return null;
            }
        }
    }
}