/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.opendistroforelasticsearch.knn.plugin.script;

import com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues;
import java.math.BigInteger;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Static helpers that compute distances and similarities between a query vector and an input
 * vector for k-NN scoring.
 *
 * <p>Each metric has a {@code float[]} overload that does the arithmetic, plus an overload taking
 * a {@code List<Number>} query and {@link KNNVectorScriptDocValues}. The latter converts the query
 * to {@code float[]} and compares it against {@code docValues.getValue()}.
 */
public class KNNScoringUtil {
    private static final Logger logger = LogManager.getLogger(KNNScoringUtil.class);

    /**
     * Validates that both vectors are non-null and have the same number of dimensions.
     *
     * @param queryVector the query vector
     * @param inputVector the vector being scored against the query
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the lengths differ; the message reports the input
     *     vector's length as "Expected" and the query vector's length as "Given"
     */
    private static void requireEqualDimension(float[] queryVector, float[] inputVector) {
        Objects.requireNonNull(queryVector, "queryVector");
        Objects.requireNonNull(inputVector, "inputVector");
        if (queryVector.length != inputVector.length) {
            // Locale.ROOT keeps the formatted numbers independent of the JVM default locale.
            String errorMessage = String.format(Locale.ROOT,
                    "query vector dimension mismatch. Expected: %d, Given: %d",
                    inputVector.length, queryVector.length);
            throw new IllegalArgumentException(errorMessage);
        }
    }

    /**
     * Computes the squared Euclidean (L2) distance: the sum of squared per-dimension differences.
     *
     * @param queryVector the query vector
     * @param inputVector the vector being scored; must have the same length as {@code queryVector}
     * @return sum over i of {@code (queryVector[i] - inputVector[i])^2}
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public static float l2Squared(float[] queryVector, float[] inputVector) {
        KNNScoringUtil.requireEqualDimension(queryVector, inputVector);
        float squaredDistance = 0.0f;
        for (int i = 0; i < inputVector.length; ++i) {
            float diff = queryVector[i] - inputVector[i];
            squaredDistance += diff * diff;
        }
        return squaredDistance;
    }

    /**
     * Converts a list of numbers to a {@code float[]} via {@link Number#floatValue()}.
     * Values that do not fit exactly in a {@code float} (e.g. most doubles or large longs) are
     * rounded.
     *
     * @param inputVector the numbers to convert
     * @return a new array with one element per list entry, in iteration order
     * @throws NullPointerException if the list or any of its elements is {@code null}
     */
    private static float[] toFloat(List<Number> inputVector) {
        Objects.requireNonNull(inputVector, "inputVector");
        float[] value = new float[inputVector.size()];
        int index = 0;
        for (Number val : inputVector) {
            value[index++] = val.floatValue();
        }
        return value;
    }

    /**
     * Computes the squared L2 distance between {@code queryVector} and the vector returned by
     * {@code docValues.getValue()}.
     *
     * @param queryVector the query vector as numbers
     * @param docValues   source of the input vector
     * @return the squared L2 distance
     * @throws IllegalArgumentException if the dimensions differ
     */
    public static float l2Squared(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
        return KNNScoringUtil.l2Squared(KNNScoringUtil.toFloat(queryVector), docValues.getValue());
    }

    /**
     * Computes cosine similarity using a precomputed query norm, so that the norm does not have
     * to be recomputed for every input vector.
     *
     * <p>{@code normQueryVector} is multiplied by the input vector's sum of squares before the
     * square root is taken. For the result to be the true cosine similarity, it must therefore be
     * the <em>squared</em> L2 norm (sum of squares) of {@code queryVector}.
     *
     * @param queryVector     the query vector
     * @param inputVector     the vector being scored; must match {@code queryVector}'s length
     * @param normQueryVector the sum of squares of {@code queryVector}
     * @return the cosine similarity, or {@link Float#MIN_VALUE} if either norm is zero
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public static float cosinesimilOptimized(float[] queryVector, float[] inputVector, float normQueryVector) {
        KNNScoringUtil.requireEqualDimension(queryVector, inputVector);
        float dotProduct = 0.0f;
        float normInputVector = 0.0f;
        for (int i = 0; i < queryVector.length; ++i) {
            dotProduct += queryVector[i] * inputVector[i];
            normInputVector += inputVector[i] * inputVector[i];
        }
        return KNNScoringUtil.cosineFromComponents(dotProduct, normQueryVector, normInputVector);
    }

    /**
     * Computes cosine similarity between {@code queryVector} and {@code docValues.getValue()}
     * using a caller-supplied query norm.
     *
     * <p>NOTE(review): {@code queryVectorMagnitude} is passed unchanged as
     * {@code normQueryVector}, which is used as a <em>squared</em> norm. Confirm that callers
     * supply the sum of squares rather than the plain magnitude.
     *
     * @param queryVector          the query vector as numbers
     * @param docValues            source of the input vector
     * @param queryVectorMagnitude precomputed query norm (see note above)
     * @return the cosine similarity, or {@link Float#MIN_VALUE} if either norm is zero
     * @throws IllegalArgumentException if the dimensions differ
     */
    public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues docValues, Number queryVectorMagnitude) {
        return KNNScoringUtil.cosinesimilOptimized(KNNScoringUtil.toFloat(queryVector), docValues.getValue(), queryVectorMagnitude.floatValue());
    }

    /**
     * Computes cosine similarity: dot product divided by the product of the two L2 norms.
     *
     * @param queryVector the query vector
     * @param inputVector the vector being scored; must match {@code queryVector}'s length
     * @return the cosine similarity, or {@link Float#MIN_VALUE} if either vector is all zeros
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public static float cosinesimil(float[] queryVector, float[] inputVector) {
        KNNScoringUtil.requireEqualDimension(queryVector, inputVector);
        float dotProduct = 0.0f;
        float normQueryVector = 0.0f;
        float normInputVector = 0.0f;
        for (int i = 0; i < queryVector.length; ++i) {
            dotProduct += queryVector[i] * inputVector[i];
            normQueryVector += queryVector[i] * queryVector[i];
            normInputVector += inputVector[i] * inputVector[i];
        }
        return KNNScoringUtil.cosineFromComponents(dotProduct, normQueryVector, normInputVector);
    }

    /**
     * Finishes a cosine computation from its accumulated components.
     *
     * @param dotProduct      dot product of the two vectors
     * @param normQueryVector sum of squares of the query vector
     * @param normInputVector sum of squares of the input vector
     * @return {@code dotProduct / sqrt(normQueryVector * normInputVector)}, or
     *     {@link Float#MIN_VALUE} if that product is zero
     */
    private static float cosineFromComponents(float dotProduct, float normQueryVector, float normInputVector) {
        // Multiply in double: two finite, non-zero float sums of squares can overflow to Infinity
        // (turning the score into 0) or underflow to 0 (wrongly treated as a zero vector) if
        // multiplied in float. Their double product is always exact enough and finite.
        double normalizedProduct = (double) normQueryVector * normInputVector;
        if (normalizedProduct == 0.0) {
            // NOTE(review): Float.MIN_VALUE is the smallest *positive* float, not the most negative
            // one. Whether this really sorts the hit last depends on how callers turn the cosine
            // into a score; confirm against the scoring code.
            logger.debug("Invalid vectors for cosine. Returning minimum score to put this result to end");
            return Float.MIN_VALUE;
        }
        return (float) (dotProduct / Math.sqrt(normalizedProduct));
    }

    /**
     * Computes cosine similarity between {@code queryVector} and {@code docValues.getValue()},
     * computing both norms from the vectors.
     *
     * @param queryVector the query vector as numbers
     * @param docValues   source of the input vector
     * @return the cosine similarity, or {@link Float#MIN_VALUE} if either vector is all zeros
     * @throws IllegalArgumentException if the dimensions differ
     */
    public static float cosineSimilarity(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
        return KNNScoringUtil.cosinesimil(KNNScoringUtil.toFloat(queryVector), docValues.getValue());
    }

    /**
     * Counts the bits that differ between two {@link BigInteger}s, computed as
     * {@code bitCount(query XOR input)}.
     *
     * <p>NOTE(review): {@link BigInteger#bitCount()} counts bits that differ from the sign bit.
     * When exactly one operand is negative, the XOR is negative and the result is not the
     * conventional Hamming distance. Confirm that inputs are non-negative.
     *
     * @param queryBigInteger the query bit pattern
     * @param inputBigInteger the bit pattern being scored
     * @return the number of differing bits, as a float
     * @throws NullPointerException if either argument is {@code null}
     */
    public static float calculateHammingBit(BigInteger queryBigInteger, BigInteger inputBigInteger) {
        return inputBigInteger.xor(queryBigInteger).bitCount();
    }

    /**
     * Computes the Hamming distance between two 64-bit values: the number of set bits in
     * {@code query XOR input}.
     *
     * @param queryLong the query bit pattern
     * @param inputLong the bit pattern being scored
     * @return the number of differing bits (0 to 64), as a float
     * @throws NullPointerException if either argument is {@code null} (auto-unboxing)
     */
    public static float calculateHammingBit(Long queryLong, Long inputLong) {
        return Long.bitCount(queryLong ^ inputLong);
    }

    /**
     * Computes the L1 (Manhattan) distance: the sum of absolute per-dimension differences.
     *
     * @param queryVector the query vector
     * @param inputVector the vector being scored; must have the same length as {@code queryVector}
     * @return sum over i of {@code |queryVector[i] - inputVector[i]|}
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vector lengths differ
     */
    public static float l1Norm(float[] queryVector, float[] inputVector) {
        KNNScoringUtil.requireEqualDimension(queryVector, inputVector);
        float distance = 0.0f;
        for (int i = 0; i < inputVector.length; ++i) {
            float diff = queryVector[i] - inputVector[i];
            distance += Math.abs(diff);
        }
        return distance;
    }

    /**
     * Computes the L1 distance between {@code queryVector} and {@code docValues.getValue()}.
     *
     * @param queryVector the query vector as numbers
     * @param docValues   source of the input vector
     * @return the L1 distance
     * @throws IllegalArgumentException if the dimensions differ
     */
    public static float l1Norm(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
        return KNNScoringUtil.l1Norm(KNNScoringUtil.toFloat(queryVector), docValues.getValue());
    }
}

