Private/OktaRSAEncryptionProvider.cs

/*
 * Okta Management
 *
 * Allows customers to easily access the Okta Management APIs
 *
 * The version of the OpenAPI document: 3.0.0
 * Contact: devex-public@okta.com
 * Generated by: https://github.com/openapitools/openapi-generator.git
 */
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Runtime.InteropServices;
using System.Security;
using System.Security.Cryptography;
using System.Text;
 
namespace RSAEncryption
{
    /// <summary>
    /// A RSA enccryption provider.
    /// </summary>
    public class RSAEncryptionProvider
    {
        /// <summary>
        /// Get the RSA provider from the PEM file.
        /// </summary>
        /// <param name="pemfile">PEM file.</param>
        /// <param name="keyPassPhrase">Key pass phrase.</param>
        /// <returns>Get an instance of RSACryptoServiceProvider.</returns>
        public static RSACryptoServiceProvider GetRSAProviderFromPemFile(String pemfile,SecureString keyPassPhrase = null)
        {
            const String pempubheader = "-----BEGIN PUBLIC KEY-----";
            const String pempubfooter = "-----END PUBLIC KEY-----";
            bool isPrivateKeyFile = true;
            byte[] pemkey = null;
 
            if (!File.Exists(pemfile))
            {
                throw new Exception("private key file does not exist.");
            }
            string pemstr = File.ReadAllText(pemfile).Trim();
 
            if (pemstr.StartsWith(pempubheader) && pemstr.EndsWith(pempubfooter))
            {
                isPrivateKeyFile = false;
            }
 
            if (isPrivateKeyFile)
            {
                pemkey = ConvertPrivateKeyToBytes(pemstr,keyPassPhrase);
                if (pemkey == null)
                {
                    return null;
                }
                return DecodeRSAPrivateKey(pemkey);
            }
            return null ;
        }
 
        /// <summary>
        /// Convert the private key to bytes.
        /// </summary>
        /// <param name="instr">Private key.</param>
        /// <param name="keyPassPhrase">Key pass phrase.</param>
        /// <returns>The private key in the form of bytes.</returns>
        static byte[] ConvertPrivateKeyToBytes(String instr, SecureString keyPassPhrase = null)
        {
            const String pemprivheader = "-----BEGIN RSA PRIVATE KEY-----";
            const String pemprivfooter = "-----END RSA PRIVATE KEY-----";
            String pemstr = instr.Trim();
            byte[] binkey;
 
            if (!pemstr.StartsWith(pemprivheader) || !pemstr.EndsWith(pemprivfooter))
            {
                return null;
            }
 
            StringBuilder sb = new StringBuilder(pemstr);
            sb.Replace(pemprivheader, "");
            sb.Replace(pemprivfooter, "");
            String pvkstr = sb.ToString().Trim();
 
            try
            { // if there are no PEM encryption info lines, this is an UNencrypted PEM private key
                binkey = Convert.FromBase64String(pvkstr);
                return binkey;
            }
            catch (System.FormatException)
            {
                StringReader str = new StringReader(pvkstr);
 
                //-------- read PEM encryption info. lines and extract salt -----
                if (!str.ReadLine().StartsWith("Proc-Type: 4,ENCRYPTED"))
                {
                    return null;
                }
                String saltline = str.ReadLine();
                if (!saltline.StartsWith("DEK-Info: DES-EDE3-CBC,"))
                {
                    return null;
                }
                String saltstr = saltline.Substring(saltline.IndexOf(",") + 1).Trim();
                byte[] salt = new byte[saltstr.Length / 2];
                for (int i = 0; i < salt.Length; i++)
                {
                    salt[i] = Convert.ToByte(saltstr.Substring(i * 2, 2), 16);
                }
                if (str.ReadLine() != "")
                {
                    return null;
                }
 
                //------ remaining b64 data is encrypted RSA key ----
                String encryptedstr = str.ReadToEnd();
 
                try
                { //should have b64 encrypted RSA key now
                    binkey = Convert.FromBase64String(encryptedstr);
                }
                catch (System.FormatException)
                { //data is not in base64 format
                    return null;
                }
 
                byte[] deskey = GetEncryptedKey(salt, keyPassPhrase, 1, 2); // count=1 (for OpenSSL implementation); 2 iterations to get at least 24 bytes
                if (deskey == null)
                {
                    return null;
                }
 
                //------ Decrypt the encrypted 3des-encrypted RSA private key ------
                byte[] rsakey = DecryptKey(binkey, deskey, salt); //OpenSSL uses salt value in PEM header also as 3DES IV
                return rsakey;
            }
        }
 
        /// <summary>
        /// Decode the RSA private key.
        /// </summary>
        /// <param name="privkey">Private key.</param>
        /// <returns>An instance of RSACryptoServiceProvider.</returns>
        public static RSACryptoServiceProvider DecodeRSAPrivateKey(byte[] privkey)
        {
            byte[] bytesModules, bytesE, bytesD, bytesP, bytesQ, bytesDp, bytesDq, bytesIq;
 
            // --------- Set up stream to decode the asn.1 encoded RSA private key ------
            MemoryStream mem = new MemoryStream(privkey);
            BinaryReader binr = new BinaryReader(mem); //wrap Memory Stream with BinaryReader for easy reading
            byte bt = 0;
            ushort twobytes = 0;
            int elems = 0;
            try
            {
                twobytes = binr.ReadUInt16();
                if (twobytes == 0x8130) //data read as little endian order (actual data order for Sequence is 30 81)
                {
                    binr.ReadByte(); //advance 1 byte
                }
                else if (twobytes == 0x8230)
                {
                    binr.ReadInt16(); //advance 2 bytes
                }
                else
                {
                    return null;
                }
 
                twobytes = binr.ReadUInt16();
                if (twobytes != 0x0102) //version number
                {
                    return null;
                }
                bt = binr.ReadByte();
                if (bt != 0x00)
                {
                    return null;
                }
 
                //------ all private key components are Integer sequences ----
                elems = GetIntegerSize(binr);
                bytesModules = binr.ReadBytes(elems);
 
                elems = GetIntegerSize(binr);
                bytesE = binr.ReadBytes(elems);
 
                elems = GetIntegerSize(binr);
                bytesD = binr.ReadBytes(elems);
 
                elems = GetIntegerSize(binr);
                bytesP = binr.ReadBytes(elems);
 
                elems = GetIntegerSize(binr);
                bytesQ = binr.ReadBytes(elems);
 
                elems = GetIntegerSize(binr);
                bytesDp = binr.ReadBytes(elems);
 
                elems = GetIntegerSize(binr);
                bytesDq = binr.ReadBytes(elems);
 
                elems = GetIntegerSize(binr);
                bytesIq = binr.ReadBytes(elems);
 
                // ------- create RSACryptoServiceProvider instance and initialize with public key -----
                RSACryptoServiceProvider rsa = new RSACryptoServiceProvider();
                RSAParameters RSAparams = new RSAParameters();
                RSAparams.Modulus = bytesModules;
                RSAparams.Exponent = bytesE;
                RSAparams.D = bytesD;
                RSAparams.P = bytesP;
                RSAparams.Q = bytesQ;
                RSAparams.DP = bytesDp;
                RSAparams.DQ = bytesDq;
                RSAparams.InverseQ = bytesIq;
                rsa.ImportParameters(RSAparams);
                return rsa;
            }
            catch (Exception)
            {
                return null;
            }
            finally
            {
                binr.Close();
            }
        }
 
        private static int GetIntegerSize(BinaryReader binr)
        {
            byte bt = 0;
            byte lowbyte = 0x00;
            byte highbyte = 0x00;
            int count = 0;
            bt = binr.ReadByte();
            if (bt != 0x02) //expect integer
            {
                return 0;
            }
            bt = binr.ReadByte();
 
            if (bt == 0x81)
            {
                count = binr.ReadByte(); // data size in next byte
            }
            else if (bt == 0x82)
            {
                highbyte = binr.ReadByte(); // data size in next 2 bytes
                lowbyte = binr.ReadByte();
                byte[] modint = { lowbyte, highbyte, 0x00, 0x00 };
                count = BitConverter.ToInt32(modint, 0);
            }
            else
            {
                count = bt; // we already have the data size
            }
            while (binr.ReadByte() == 0x00)
            { //remove high order zeros in data
                count -= 1;
            }
            binr.BaseStream.Seek(-1, SeekOrigin.Current);
            //last ReadByte wasn't a removed zero, so back up a byte
            return count;
        }
 
        /// <summary>
        /// Get the encrypted key.
        /// </summary>
        /// <param name="salt">Random bytes to be added.</param>
        /// <param name="secpswd">Password.</param>
        /// <param name="count">Count.</param>
        /// <param name="miter">Miter.</param>
        /// <returns>Decrypted key.</returns>
        static byte[] GetEncryptedKey(byte[] salt, SecureString secpswd, int count, int miter)
        {
            IntPtr unmanagedPswd = IntPtr.Zero;
            const int HASHLENGTH = 16; //MD5 bytes
            byte[] keymaterial = new byte[HASHLENGTH * miter]; //to store concatenated Mi hashed results
 
            byte[] psbytes = new byte[secpswd.Length];
            unmanagedPswd = Marshal.SecureStringToGlobalAllocAnsi(secpswd);
            Marshal.Copy(unmanagedPswd, psbytes, 0, psbytes.Length);
            Marshal.ZeroFreeGlobalAllocAnsi(unmanagedPswd);
 
            // --- concatenate salt and pswd bytes into fixed data array ---
            byte[] data00 = new byte[psbytes.Length + salt.Length];
            Array.Copy(psbytes, data00, psbytes.Length); //copy the pswd bytes
            Array.Copy(salt, 0, data00, psbytes.Length, salt.Length); //concatenate the salt bytes
 
            // ---- do multi-hashing and concatenate results D1, D2 ... into keymaterial bytes ----
            MD5 md5 = new MD5CryptoServiceProvider();
            byte[] result = null;
            byte[] hashtarget = new byte[HASHLENGTH + data00.Length]; //fixed length initial hashtarget
 
            for (int j = 0; j < miter; j++)
            {
                // ---- Now hash consecutively for count times ------
                if (j == 0)
                {
                    result = data00; //initialize
                }
                else
                {
                    Array.Copy(result, hashtarget, result.Length);
                    Array.Copy(data00, 0, hashtarget, result.Length, data00.Length);
                    result = hashtarget;
                }
 
                for (int i = 0; i < count; i++)
                {
                    result = md5.ComputeHash(result);
                }
                Array.Copy(result, 0, keymaterial, j * HASHLENGTH, result.Length); //concatenate to keymaterial
            }
            byte[] deskey = new byte[24];
            Array.Copy(keymaterial, deskey, deskey.Length);
 
            Array.Clear(psbytes, 0, psbytes.Length);
            Array.Clear(data00, 0, data00.Length);
            Array.Clear(result, 0, result.Length);
            Array.Clear(hashtarget, 0, hashtarget.Length);
            Array.Clear(keymaterial, 0, keymaterial.Length);
            return deskey;
        }
 
        /// <summary>
        /// Decrypt the key.
        /// </summary>
        /// <param name="chipherData">Cipher data.</param>
        /// <param name="desKey">Key to decrypt.</param>
        /// <param name="IV">Initialization vector.</param>
        /// <returns>Decrypted key.</returns>
        static byte[] DecryptKey(byte[] cipherData, byte[] desKey, byte[] IV)
        {
            MemoryStream memst = new MemoryStream();
            TripleDES alg = TripleDES.Create();
            alg.Key = desKey;
            alg.IV = IV;
            try
            {
                CryptoStream cs = new CryptoStream(memst, alg.CreateDecryptor(), CryptoStreamMode.Write);
                cs.Write(cipherData, 0, cipherData.Length);
                cs.Close();
            }
            catch (Exception)
            {
                return null;
            }
            byte[] decryptedData = memst.ToArray();
            return decryptedData;
        }
    }
}