/*
 * Decompiled with CFR 0.152.
 */
package org.apache.nifi.security.util.crypto;

import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.crypto.Cipher;
import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.lang3.StringUtils;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.security.util.EncryptionMethod;
import org.apache.nifi.security.util.crypto.AESKeyedCipherProvider;
import org.apache.nifi.security.util.crypto.Argon2SecureHasher;
import org.apache.nifi.security.util.crypto.CipherUtility;
import org.apache.nifi.security.util.crypto.KeyedCipherProvider;
import org.apache.nifi.security.util.crypto.RandomIVPBECipherProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Argon2CipherProvider
extends RandomIVPBECipherProvider {
    private static final Logger logger = LoggerFactory.getLogger(Argon2CipherProvider.class);
    private static final int DEFAULT_PARALLELISM = 8;
    private static final int DEFAULT_MEMORY = 65536;
    private static final int DEFAULT_ITERATIONS = 5;
    private static final int DEFAULT_SALT_LENGTH = 16;
    private final Integer memory;
    private final int parallelism;
    private final Integer iterations;
    private static final Pattern ARGON2_SALT_FORMAT = Pattern.compile("^\\$argon2id\\$v=19\\$m=\\d+,t=\\d+,p=\\d+\\$[\\w\\+\\/]{22}(?=\\$?)$");

    public Argon2CipherProvider() {
        this(65536, 8, 5);
    }

    public Argon2CipherProvider(Integer memory, int parallelism, Integer iterations) {
        this.memory = memory;
        this.parallelism = parallelism;
        this.iterations = iterations;
        if (memory < 65536) {
            logger.warn("The provided memory size {} KiB is below the recommended minimum {} KiB", (Object)memory, (Object)65536);
        }
        if (parallelism < 8) {
            logger.warn("The provided parallelization factor {} is below the recommended minimum {}", (Object)parallelism, (Object)8);
        }
        if (iterations < 5) {
            logger.warn("The provided iterations count {} is below the recommended minimum {}", (Object)iterations, (Object)5);
        }
    }

    @Override
    public Cipher getCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, byte[] iv, int keyLength, boolean encryptMode) throws Exception {
        try {
            return this.getInitializedCipher(encryptionMethod, password, salt, iv, keyLength, encryptMode);
        }
        catch (IllegalArgumentException e) {
            throw e;
        }
        catch (Exception e) {
            throw new ProcessException("Error initializing the cipher", (Throwable)e);
        }
    }

    @Override
    Logger getLogger() {
        return logger;
    }

    @Override
    public Cipher getCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, int keyLength, boolean encryptMode) throws Exception {
        return this.getCipher(encryptionMethod, password, salt, new byte[0], keyLength, encryptMode);
    }

    protected Cipher getInitializedCipher(EncryptionMethod encryptionMethod, String password, byte[] salt, byte[] iv, int keyLength, boolean encryptMode) throws Exception {
        int parallelism;
        int iterations;
        int memory;
        if (encryptionMethod == null) {
            throw new IllegalArgumentException("The encryption method must be specified");
        }
        if (!encryptionMethod.isCompatibleWithStrongKDFs()) {
            throw new IllegalArgumentException(encryptionMethod.name() + " is not compatible with Argon2");
        }
        if (StringUtils.isEmpty((CharSequence)password)) {
            throw new IllegalArgumentException("Encryption with an empty password is not supported");
        }
        String algorithm = encryptionMethod.getAlgorithm();
        String cipherName = CipherUtility.parseCipherFromAlgorithm(algorithm);
        if (!CipherUtility.isValidKeyLength(keyLength, cipherName)) {
            throw new IllegalArgumentException(keyLength + " is not a valid key length for " + cipherName);
        }
        String saltString = new String(salt, StandardCharsets.UTF_8);
        byte[] rawSalt = new byte[this.getDefaultSaltLength()];
        if (Argon2CipherProvider.isArgon2FormattedSalt(saltString)) {
            ArrayList<Integer> params = new ArrayList<Integer>(3);
            this.parseSalt(saltString, rawSalt, params);
            memory = (Integer)params.get(0);
            iterations = (Integer)params.get(1);
            parallelism = (Integer)params.get(2);
        } else {
            rawSalt = salt;
            memory = this.getMemory();
            iterations = this.getIterations();
            parallelism = this.getParallelism();
        }
        Argon2SecureHasher argon2SecureHasher = new Argon2SecureHasher(keyLength / 8, memory, parallelism, iterations);
        try {
            byte[] keyBytes = argon2SecureHasher.hashRaw(password.getBytes(StandardCharsets.UTF_8), rawSalt);
            SecretKeySpec tempKey = new SecretKeySpec(keyBytes, algorithm);
            AESKeyedCipherProvider keyedCipherProvider = new AESKeyedCipherProvider();
            return ((KeyedCipherProvider)keyedCipherProvider).getCipher(encryptionMethod, tempKey, iv, encryptMode);
        }
        catch (IllegalArgumentException e) {
            if (e.getMessage().contains("The salt length")) {
                throw new IllegalArgumentException("The raw salt must be greater than or equal to 8 bytes", e);
            }
            logger.error("Encountered an error generating the Argon2 hash", (Throwable)e);
            throw e;
        }
    }

    public static byte[] extractRawSaltFromArgon2Salt(String argon2Salt) {
        String[] saltComponents = argon2Salt.split("\\$");
        if (saltComponents.length < 4) {
            throw new IllegalArgumentException("Could not parse salt");
        }
        return Base64.decodeBase64((String)saltComponents[saltComponents.length - 1]);
    }

    public static boolean isArgon2FormattedSalt(String salt) {
        if (salt == null || salt.length() == 0) {
            throw new IllegalArgumentException("The salt cannot be empty. To generate a salt, use Argon2CipherProvider#generateSalt()");
        }
        Matcher matcher = ARGON2_SALT_FORMAT.matcher(salt);
        return matcher.find();
    }

    private void parseSalt(String argon2Salt, byte[] rawSalt, List<Integer> params) {
        if (StringUtils.isEmpty((CharSequence)argon2Salt)) {
            throw new IllegalArgumentException("Cannot parse empty salt");
        }
        byte[] salt = Argon2CipherProvider.extractRawSaltFromArgon2Salt(argon2Salt);
        if (rawSalt.length < salt.length) {
            byte[] tempBytes = new byte[salt.length];
            System.arraycopy(rawSalt, 0, tempBytes, 0, rawSalt.length);
            rawSalt = tempBytes;
        }
        System.arraycopy(salt, 0, rawSalt, 0, rawSalt.length);
        if (params == null) {
            params = new ArrayList<Integer>(3);
        }
        String[] saltComponents = argon2Salt.split("\\$");
        Map<String, String> saltParams = Arrays.stream(saltComponents[3].split(",")).collect(Collectors.toMap(pair -> pair.split("=")[0], pair -> pair.split("=")[1]));
        params.add(Integer.parseInt(saltParams.get("m")));
        params.add(Integer.parseInt(saltParams.get("t")));
        params.add(Integer.parseInt(saltParams.get("p")));
    }

    @Override
    public byte[] generateSalt() {
        byte[] rawSalt = new byte[16];
        new SecureRandom().nextBytes(rawSalt);
        return Argon2CipherProvider.formSalt(rawSalt, this.getMemory(), this.getIterations(), this.getParallelism()).getBytes(StandardCharsets.UTF_8);
    }

    public static String formSalt(byte[] rawSalt, int memory, int iterations, int parallelism) {
        StringBuilder sb = new StringBuilder("$argon2id$");
        sb.append("v=19").append("$");
        sb.append("m=").append(memory).append(",");
        sb.append("t=").append(iterations).append(",");
        sb.append("p=").append(parallelism).append("$");
        sb.append(CipherUtility.encodeBase64NoPadding(rawSalt));
        return sb.toString();
    }

    @Override
    public int getDefaultSaltLength() {
        return 16;
    }

    protected int getMemory() {
        return this.memory;
    }

    protected int getParallelism() {
        return this.parallelism;
    }

    protected int getIterations() {
        return this.iterations;
    }
}

