package com.shinyhut.vernacular.protocol.auth;

import com.shinyhut.vernacular.client.VncSession;
import com.shinyhut.vernacular.client.exceptions.SecurityTypeFailedException;
import com.shinyhut.vernacular.client.exceptions.VncException;
import com.shinyhut.vernacular.protocol.messages.SecurityResult;
import com.shinyhut.vernacular.utils.AesEaxInputStream;
import com.shinyhut.vernacular.utils.AesEaxOutputStream;
import com.shinyhut.vernacular.utils.ByteUtils;
import com.shinyhut.vernacular.utils.CryptoUtils;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.MessageDigest;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.RSAPublicKeySpec;
import java.util.Arrays;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import org.eclipse.swt.internal.win32.OS;

/* loaded from: input_file:BOOT-INF/core/vernacular-vnc-1.17.jar:com/shinyhut/vernacular/protocol/auth/RsaAesAuthenticationHandler.class */
/**
 * Client side of the RSA-AES ("RA2") VNC security handshake.
 *
 * <p>{@link #authenticate(VncSession)} performs these steps in order:
 * <ol>
 *   <li>read the server's RSA public key;</li>
 *   <li>send a freshly generated client RSA key of the same bit length;</li>
 *   <li>exchange RSA-encrypted random values;</li>
 *   <li>derive one AES-EAX key per direction from those randoms;</li>
 *   <li>exchange hashes of both public keys over the encrypted channel;</li>
 *   <li>read the authentication subtype and send the credentials.</li>
 * </ol>
 *
 * <p>{@link #RA2(int)} uses 128-bit session keys derived with SHA-1. {@link #RA2_256(int)} uses
 * 256-bit session keys derived with SHA-256.
 *
 * <p>All handshake state lives in mutable instance fields. NOTE(review): presumably one instance
 * serves a single {@code authenticate} call; confirm with callers.
 */
public class RsaAesAuthenticationHandler implements SecurityHandler {
    private static final int MIN_KEY_LENGTH = 1024;
    private static final int MAX_KEY_LENGTH = 8192;
    /** Longest credential, in UTF-8 bytes, that fits the one-byte length prefix. */
    private static final int MAX_CREDENTIAL_LENGTH = 255;
    private static final int SUBTYPE_USERNAME_AND_PASSWORD = 1;
    private static final int SUBTYPE_PASSWORD_ONLY = 2;

    /** AES session key size in bits, chosen by the factory method (128 or 256). */
    private final int keySize;
    /** Digest used both for session-key derivation and for the key-exchange hashes. */
    private final MessageDigest digest;
    /** Security type byte sent to the server unless the protocol version is 3.3. */
    private final int authCode;

    private int subtype;
    private PrivateKey clientKey;
    private PublicKey clientPublicKey;
    private PublicKey serverKey;
    private int serverKeyLength;
    private byte[] serverKeyN;
    private byte[] serverKeyE;
    private int clientKeyLength;
    private byte[] clientKeyN;
    private byte[] clientKeyE;
    private byte[] serverRandom;
    private byte[] clientRandom;
    private InputStream rawInput;
    private OutputStream rawOutput;
    private AesEaxInputStream encryptedInput;
    private AesEaxOutputStream encryptedOutput;
    private VncSession session;

    /**
     * Creates a handler using 128-bit AES session keys derived with SHA-1.
     *
     * @param i security type code to send when selecting this security type
     * @return a new handler
     * @throws VncException if the SHA-1 digest cannot be obtained
     */
    public static RsaAesAuthenticationHandler RA2(int i) throws VncException {
        return new RsaAesAuthenticationHandler(128, CryptoUtils.sha1(), i);
    }

    /**
     * Creates a handler using 256-bit AES session keys derived with SHA-256.
     *
     * @param i security type code to send when selecting this security type
     * @return a new handler
     * @throws VncException if the SHA-256 digest cannot be obtained
     */
    public static RsaAesAuthenticationHandler RA2_256(int i) throws VncException {
        return new RsaAesAuthenticationHandler(256, CryptoUtils.sha256(), i);
    }

    private RsaAesAuthenticationHandler(int i, MessageDigest messageDigest, int i2) {
        this.keySize = i;
        this.digest = messageDigest;
        this.authCode = i2;
    }

    /**
     * Runs the full RSA-AES handshake on the session's raw streams.
     *
     * @param vncSession session whose input/output streams and config are used
     * @return the security result decoded from the raw input stream
     * @throws VncException if any handshake step fails validation
     * @throws IOException on stream errors
     */
    @Override // com.shinyhut.vernacular.protocol.auth.SecurityHandler
    public SecurityResult authenticate(VncSession vncSession) throws VncException, IOException {
        this.session = vncSession;
        this.rawInput = vncSession.getInputStream();
        this.rawOutput = vncSession.getOutputStream();
        // In RFB 3.3 the server dictates the security type, so no selection byte is sent.
        if (!vncSession.getProtocolVersion().equals(3, 3)) {
            requestAuthentication();
        }
        readPublicKey();
        verifyServer();
        writePublicKey();
        writeRandom();
        readRandom();
        setCipher();
        writeHash();
        readHash();
        readSubtype();
        writeCredentials();
        return SecurityResult.decode(this.rawInput, vncSession.getProtocolVersion());
    }

    /** Sends the one-byte security type selection. */
    private void requestAuthentication() throws IOException {
        this.rawOutput.write(this.authCode);
    }

    /**
     * Reads and validates the server's RSA public key.
     *
     * <p>Wire format: a 32-bit bit length, then the modulus and the exponent, each padded to
     * {@code ceil(bitLength / 8)} bytes.
     */
    private void readPublicKey() throws VncException, IOException {
        DataInputStream in = new DataInputStream(this.rawInput);
        this.serverKeyLength = in.readInt();
        if (this.serverKeyLength < MIN_KEY_LENGTH) {
            throw new SecurityTypeFailedException("Server key is too short");
        }
        if (this.serverKeyLength > MAX_KEY_LENGTH) {
            throw new SecurityTypeFailedException("Server key is too long");
        }
        int byteLength = bytesForBits(this.serverKeyLength);
        this.serverKeyN = new byte[byteLength];
        this.serverKeyE = new byte[byteLength];
        in.readFully(this.serverKeyN);
        in.readFully(this.serverKeyE);
        try {
            // Both values arrive as unsigned big-endian magnitudes.
            RSAPublicKeySpec spec = new RSAPublicKeySpec(
                    new BigInteger(1, this.serverKeyN), new BigInteger(1, this.serverKeyE));
            this.serverKey = CryptoUtils.rsaKeyFactory().generatePublic(spec);
        } catch (InvalidKeySpecException e) {
            throw new SecurityTypeFailedException("Server key is invalid");
        }
    }

    /**
     * Computes the SHA-1 fingerprint of the server key.
     *
     * <p>NOTE(review): the fingerprint is computed and then discarded. The server's identity is
     * not actually checked against anything (no known-hosts store or user callback). Consider
     * surfacing it to the application.
     */
    private void verifyServer() throws VncException, IOException {
        MessageDigest sha1 = CryptoUtils.sha1();
        updateWithKey(sha1, this.serverKeyLength, this.serverKeyN, this.serverKeyE);
        sha1.digest();
    }

    /**
     * Generates a client RSA key pair with the same bit length as the server key and sends the
     * public part in the same wire format as {@link #readPublicKey()}.
     */
    private void writePublicKey() throws VncException, IOException {
        DataOutputStream out = new DataOutputStream(this.rawOutput);
        this.clientKeyLength = this.serverKeyLength;
        KeyPairGenerator generator = CryptoUtils.rsaKeyPairGenerator();
        generator.initialize(this.clientKeyLength);
        KeyPair keyPair = generator.generateKeyPair();
        this.clientKey = keyPair.getPrivate();
        this.clientPublicKey = keyPair.getPublic();
        RSAPublicKey rsaPublicKey = (RSAPublicKey) this.clientPublicKey;
        int byteLength = bytesForBits(this.clientKeyLength);
        this.clientKeyN = ByteUtils.bigIntToBytes(rsaPublicKey.getModulus(), byteLength, false);
        this.clientKeyE = ByteUtils.bigIntToBytes(rsaPublicKey.getPublicExponent(), byteLength, false);
        out.writeInt(this.clientKeyLength);
        out.write(this.clientKeyN);
        out.write(this.clientKeyE);
    }

    /**
     * Generates {@code keySize / 8} random bytes and sends them encrypted with the server key
     * (RSA/PKCS#1 v1.5), prefixed with a 16-bit ciphertext length.
     */
    private void writeRandom() throws VncException, IOException {
        DataOutputStream out = new DataOutputStream(this.rawOutput);
        this.clientRandom = new byte[this.keySize / 8];
        new SecureRandom().nextBytes(this.clientRandom);
        try {
            Cipher cipher = CryptoUtils.rsaEcbPkcs1PaddingCipher();
            cipher.init(Cipher.ENCRYPT_MODE, this.serverKey);
            byte[] encrypted = cipher.doFinal(this.clientRandom);
            out.writeShort(encrypted.length);
            out.write(encrypted);
        } catch (InvalidKeyException | BadPaddingException | IllegalBlockSizeException e) {
            throw new SecurityTypeFailedException("Failed to encrypt random");
        }
    }

    /**
     * Reads the server's random, encrypted with the client key, and decrypts it.
     *
     * <p>The ciphertext length must equal the client modulus byte length, and the plaintext must
     * be exactly {@code keySize / 8} bytes.
     */
    private void readRandom() throws VncException, IOException {
        DataInputStream in = new DataInputStream(this.rawInput);
        // The length field is an unsigned 16-bit value on the wire.
        int length = in.readUnsignedShort();
        if (length != this.clientKeyN.length) {
            throw new SecurityTypeFailedException("Client key length doesn't match");
        }
        byte[] encrypted = new byte[length];
        in.readFully(encrypted);
        try {
            Cipher cipher = CryptoUtils.rsaEcbPkcs1PaddingCipher();
            cipher.init(Cipher.DECRYPT_MODE, this.clientKey);
            this.serverRandom = cipher.doFinal(encrypted);
        } catch (InvalidKeyException | BadPaddingException | IllegalBlockSizeException e) {
            throw new SecurityTypeFailedException("Failed to decrypt server random");
        }
        if (this.serverRandom.length != this.keySize / 8) {
            throw new SecurityTypeFailedException("Server random length doesn't match");
        }
    }

    /**
     * Sets up the two encrypted streams, one AES-EAX key per direction.
     *
     * <p>The inbound key is H(clientRandom || serverRandom) and the outbound key is
     * H(serverRandom || clientRandom). Each is truncated to {@code keySize / 8} bytes.
     */
    private void setCipher() throws VncException, IOException {
        byte[] inboundKey = deriveSessionKey(this.clientRandom, this.serverRandom);
        this.encryptedInput = new AesEaxInputStream(inboundKey, this.rawInput);
        byte[] outboundKey = deriveSessionKey(this.serverRandom, this.clientRandom);
        this.encryptedOutput = new AesEaxOutputStream(outboundKey, this.rawOutput);
    }

    /** Sends H(clientKey || serverKey) over the encrypted channel. */
    private void writeHash() throws VncException, IOException {
        this.digest.reset();
        updateWithKey(this.digest, this.clientKeyLength, this.clientKeyN, this.clientKeyE);
        updateWithKey(this.digest, this.serverKeyLength, this.serverKeyN, this.serverKeyE);
        this.encryptedOutput.write(this.digest.digest());
    }

    /**
     * Reads the server's H(serverKey || clientKey) from the encrypted channel and checks it.
     *
     * @throws SecurityTypeFailedException if the received hash does not match
     */
    void readHash() throws VncException, IOException {
        this.digest.reset();
        updateWithKey(this.digest, this.serverKeyLength, this.serverKeyN, this.serverKeyE);
        updateWithKey(this.digest, this.clientKeyLength, this.clientKeyN, this.clientKeyE);
        byte[] expected = this.digest.digest();
        byte[] received = this.encryptedInput.read();
        // Constant-time comparison avoids leaking how many leading bytes matched.
        if (!MessageDigest.isEqual(received, expected)) {
            throw new SecurityTypeFailedException("Hash doesn't match");
        }
    }

    /**
     * Reads the authentication subtype from the first byte of the next encrypted message.
     * Only subtype 1 (username and password) and subtype 2 (password only) are accepted.
     */
    private void readSubtype() throws VncException, IOException {
        byte[] message = this.encryptedInput.read();
        if (message == null || message.length == 0) {
            throw new SecurityTypeFailedException("Missing RSA-AES authentication subtype");
        }
        this.subtype = message[0];
        if (this.subtype != SUBTYPE_USERNAME_AND_PASSWORD && this.subtype != SUBTYPE_PASSWORD_ONLY) {
            throw new SecurityTypeFailedException("Unknown RSA-AES authentication subtype");
        }
    }

    /**
     * Sends the credentials over the encrypted channel.
     *
     * <p>Message layout: [username length][username][password length][password], with one-byte
     * lengths. The username is always sent empty.
     *
     * <p>NOTE(review): this is also done for subtype 1 (username and password). Confirm whether
     * a username should be supplied there.
     *
     * @throws SecurityTypeFailedException if no password is available, or if it exceeds
     *     {@value #MAX_CREDENTIAL_LENGTH} UTF-8 bytes and cannot be length-prefixed
     */
    private void writeCredentials() throws VncException, IOException {
        String password = this.session.getConfig().getPasswordSupplier().get();
        if (password == null) {
            throw new SecurityTypeFailedException("No password supplied");
        }
        byte[] passwordBytes = password.getBytes(StandardCharsets.UTF_8);
        // A one-byte length prefix cannot describe longer passwords. Casting would silently
        // truncate it and corrupt the message, so fail explicitly instead.
        if (passwordBytes.length > MAX_CREDENTIAL_LENGTH) {
            throw new SecurityTypeFailedException("Password is too long");
        }
        ByteBuffer message = ByteBuffer.allocate(2 + passwordBytes.length);
        message.put((byte) 0); // empty username
        message.put((byte) passwordBytes.length);
        message.put(passwordBytes);
        this.encryptedOutput.write(message.array());
    }

    /**
     * Derives a session key as H(first || second), truncated to {@code keySize / 8} bytes.
     * Resets the shared digest before use.
     */
    private byte[] deriveSessionKey(byte[] first, byte[] second) {
        this.digest.reset();
        this.digest.update(first);
        this.digest.update(second);
        return Arrays.copyOf(this.digest.digest(), this.keySize / 8);
    }

    /** Feeds a public key into {@code md} as: 4-byte big-endian bit length, modulus, exponent. */
    private static void updateWithKey(MessageDigest md, int bitLength, byte[] modulus, byte[] exponent) {
        md.update(toBigEndianBytes(bitLength));
        md.update(modulus);
        md.update(exponent);
    }

    /** Returns the 4-byte big-endian encoding of {@code value}. */
    private static byte[] toBigEndianBytes(int value) {
        return ByteBuffer.allocate(Integer.BYTES).putInt(value).array();
    }

    /** Returns the number of bytes needed to hold {@code bits} bits, i.e. ceil(bits / 8). */
    private static int bytesForBits(int bits) {
        return (bits + 7) / 8;
    }
}
