/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.opendistroforelasticsearch.security.configuration;

import com.amazon.opendistroforelasticsearch.security.configuration.DlsFlsRequestValve;
import com.amazon.opendistroforelasticsearch.security.configuration.DlsQueryParser;
import com.amazon.opendistroforelasticsearch.security.support.HeaderHelper;
import com.amazon.opendistroforelasticsearch.security.support.OpenDistroSecurityUtils;
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Field;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.StreamSupport;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.RealtimeRequest;
import org.elasticsearch.action.admin.indices.shrink.ResizeRequest;
import org.elasticsearch.action.bulk.BulkItemRequest;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkShardRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.index.query.ParsedQuery;
import org.elasticsearch.search.DocValueFormat;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.terms.InternalTerms;
import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.threadpool.ThreadPool;

/**
 * {@link DlsFlsRequestValve} implementation that enforces the restrictions needed by
 * document-level security (DLS), field-level security (FLS) and field masking.
 *
 * <p>It does three things:
 * <ul>
 *   <li>rejects requests that cannot be combined with these features (updates, resizes and,
 *       under DLS, profiled searches) and disables realtime reads and the search request cache;</li>
 *   <li>injects the DLS queries found in the thread context into shard-level search contexts;</li>
 *   <li>merges adjacent {@link StringTerms} buckets that compare equal under the aggregation's
 *       reduce order in query-phase results.</li>
 * </ul>
 */
public class DlsFlsValveImpl implements DlsFlsRequestValve {
    private static final Logger log = LogManager.getLogger(DlsFlsValveImpl.class);

    private static final String UPDATE_NOT_SUPPORTED =
            "Update is not supported when FLS or DLS or Fieldmasking is activated";
    private static final String RESIZE_NOT_SUPPORTED =
            "Resize is not supported when FLS or DLS or Fieldmasking is activated";
    private static final String PROFILE_NOT_SUPPORTED =
            "Profiling is not supported when DLS is activated";

    /**
     * Decides whether {@code request} may proceed under the given FLS, field-masking and DLS
     * rules, adjusting it where necessary.
     *
     * <p>If any of the three rule maps is non-empty, realtime reads and the search request cache
     * are switched off. Update requests are then rejected, whether sent directly or inside a bulk
     * or bulk-shard request, and so are resize requests. If DLS rules are present, a search
     * request whose source enables profiling is rejected as well.
     *
     * @param request          the incoming request; may be mutated (realtime flag / request cache)
     * @param listener         receives an {@link ElasticsearchSecurityException} on rejection
     * @param allowedFlsFields FLS rules; {@code null} or empty means FLS is not active
     * @param maskedFields     field-masking rules; {@code null} or empty means masking is not active
     * @param queries          DLS queries; {@code null} or empty means DLS is not active
     * @return {@code true} if the request may proceed; {@code false} if it was rejected, in which
     *         case {@code listener.onFailure} has already been called
     */
    @Override
    public boolean invoke(ActionRequest request, ActionListener<?> listener, Map<String, Set<String>> allowedFlsFields, Map<String, Set<String>> maskedFields, Map<String, Set<String>> queries) {
        final boolean fls = allowedFlsFields != null && !allowedFlsFields.isEmpty();
        final boolean masked = maskedFields != null && !maskedFields.isEmpty();
        final boolean dls = queries != null && !queries.isEmpty();

        if (fls || masked || dls) {
            if (request instanceof RealtimeRequest) {
                ((RealtimeRequest) request).realtime(false);
            }
            if (request instanceof SearchRequest) {
                ((SearchRequest) request).requestCache(Boolean.FALSE);
            }
            if (request instanceof UpdateRequest) {
                return reject(listener, UPDATE_NOT_SUPPORTED);
            }
            if (request instanceof BulkRequest) {
                for (DocWriteRequest<?> inner : ((BulkRequest) request).requests()) {
                    if (inner instanceof UpdateRequest) {
                        return reject(listener, UPDATE_NOT_SUPPORTED);
                    }
                }
            }
            if (request instanceof BulkShardRequest) {
                for (BulkItemRequest item : ((BulkShardRequest) request).items()) {
                    // NOTE(review): items arrays may contain null slots; a null slot carries no update.
                    if (item != null && item.request() instanceof UpdateRequest) {
                        return reject(listener, UPDATE_NOT_SUPPORTED);
                    }
                }
            }
            if (request instanceof ResizeRequest) {
                return reject(listener, RESIZE_NOT_SUPPORTED);
            }
        }

        if (dls && request instanceof SearchRequest) {
            final SearchSourceBuilder source = ((SearchRequest) request).source();
            if (source != null && source.profile()) {
                return reject(listener, PROFILE_NOT_SUPPORTED);
            }
        }
        return true;
    }

    /**
     * Fails {@code listener} with an {@link ElasticsearchSecurityException} carrying {@code message}.
     *
     * @return always {@code false}, so callers can {@code return reject(...)}
     */
    private static boolean reject(ActionListener<?> listener, String message) {
        listener.onFailure(new ElasticsearchSecurityException(message));
        return false;
    }

    /**
     * Applies the DLS queries carried in the thread context to a shard-level search.
     *
     * <p>The DLS map is deserialized from the {@code _opendistro_security_dls_query} header and
     * resolved against the shard's index name. If a non-empty set of queries matches, they are
     * parsed together with the context's current query. The result becomes the context's parsed
     * query, and the context is pre-processed again. Contexts with a suggest section are left
     * untouched.
     *
     * @throws RuntimeException wrapping any exception raised while reading, resolving or parsing
     *                          the DLS queries
     */
    @Override
    public void handleSearchContext(SearchContext context, ThreadPool threadPool, NamedXContentRegistry namedXContentRegistry) {
        try {
            @SuppressWarnings("unchecked") // header payload is written as Map<String, Set<String>>
            final Map<String, Set<String>> queries = (Map<String, Set<String>>) HeaderHelper.deserializeSafeFromHeader(threadPool.getThreadContext(), "_opendistro_security_dls_query");
            final String dlsEval = OpenDistroSecurityUtils.evalMap(queries, context.indexShard().indexSettings().getIndex().getName());
            if (dlsEval == null || context.suggest() != null) {
                return;
            }
            assert context.parsedQuery() != null;
            final Set<String> unparsedDlsQueries = queries.get(dlsEval);
            if (unparsedDlsQueries != null && !unparsedDlsQueries.isEmpty()) {
                final ParsedQuery dlsQuery = DlsQueryParser.parse(unparsedDlsQueries, context.parsedQuery(), context.getQueryShardContext(), namedXContentRegistry);
                context.parsedQuery(dlsQuery);
                context.preProcess(true);
            }
        } catch (Exception e) {
            throw new RuntimeException("Error evaluating dls for a search query: " + e, e);
        }
    }

    /**
     * Replaces the aggregations of a shard-level query result with a copy in which adjacent
     * {@link StringTerms} buckets that compare equal under the aggregation's reduce order are
     * merged. Other aggregation types pass through unchanged. Results without aggregations are
     * left untouched.
     */
    @Override
    public void onQueryPhase(QuerySearchResult queryResult) {
        if (queryResult.aggregations() == null) {
            return;
        }
        final InternalAggregations aggregations = (InternalAggregations) queryResult.aggregations().expand();
        assert aggregations != null;
        final List<InternalAggregation> merged = StreamSupport.stream(aggregations.spliterator(), false)
                .map(aggregation -> aggregateBuckets((InternalAggregation) aggregation))
                .collect(ImmutableList.toImmutableList());
        queryResult.aggregations(InternalAggregations.from(merged));
    }

    /**
     * Returns {@code aggregation} unchanged unless it is a {@link StringTerms} with more than one
     * bucket. In that case it returns a new {@link StringTerms} whose buckets have been merged by
     * {@link #mergeBuckets}.
     */
    private static InternalAggregation aggregateBuckets(InternalAggregation aggregation) {
        if (!(aggregation instanceof StringTerms)) {
            return aggregation;
        }
        final StringTerms stringTerms = (StringTerms) aggregation;
        final List<StringTerms.Bucket> buckets = stringTerms.getBuckets();
        if (buckets.size() <= 1) {
            return aggregation;
        }
        return stringTerms.create(mergeBuckets(buckets, StringTermsGetter.getReduceOrder(stringTerms).comparator()));
    }

    /**
     * Sorts {@code buckets} by {@code comparator} and collapses each run of equal buckets into one.
     * The input list is not modified.
     *
     * @return an immutable list of the merged buckets, in {@code comparator} order
     */
    private static List<StringTerms.Bucket> mergeBuckets(List<StringTerms.Bucket> buckets, Comparator<MultiBucketsAggregation.Bucket> comparator) {
        if (log.isDebugEnabled()) {
            log.debug("Merging buckets: {}", bucketKeys(buckets));
        }
        final BucketMerger merger = new BucketMerger(comparator, buckets.size());
        // Sort a stream view (stable, like List.sort) so the source aggregation's bucket list is not reordered.
        buckets.stream().sorted(comparator).forEach(merger);
        final List<StringTerms.Bucket> merged = merger.getBuckets();
        if (log.isDebugEnabled()) {
            log.debug("New buckets: {}", bucketKeys(merged));
        }
        return merged;
    }

    /** Returns the string keys of {@code buckets}, in order; used for debug logging. */
    private static List<String> bucketKeys(List<StringTerms.Bucket> buckets) {
        return buckets.stream().map(StringTerms.Bucket::getKeyAsString).collect(ImmutableList.toImmutableList());
    }

    /** Reflective accessors for non-public fields of {@link StringTerms} and its buckets. */
    private static final class StringTermsGetter {
        private static final Field REDUCE_ORDER = getField(InternalTerms.class, "reduceOrder");
        private static final Field TERM_BYTES = getField(StringTerms.Bucket.class, "termBytes");
        private static final Field FORMAT = getField(InternalTerms.Bucket.class, "format");

        private StringTermsGetter() {
        }

        /** Looks up the declared field {@code name} on {@code cls} and makes it accessible. */
        private static Field getFieldPrivileged(Class<?> cls, String name) {
            try {
                final Field field = cls.getDeclaredField(name);
                field.setAccessible(true);
                return field;
            } catch (NoSuchFieldException | SecurityException e) {
                log.error("Failed to get class {} declared field {}", cls.getSimpleName(), name, e);
                throw asRuntimeException(e);
            }
        }

        /** Runs {@link #getFieldPrivileged} inside a privileged block after checking {@link SpecialPermission}. */
        private static Field getField(Class<?> cls, String name) {
            SpecialPermission.check();
            // Explicit PrivilegedAction target: a bare lambda is ambiguous with PrivilegedExceptionAction.
            return AccessController.doPrivileged((PrivilegedAction<Field>) () -> getFieldPrivileged(cls, name));
        }

        /** Reads {@code field} from {@code target}, cast to the caller's expected type. */
        @SuppressWarnings("unchecked")
        private static <T> T getFieldValue(Field field, Object target) {
            try {
                return (T) field.get(target);
            } catch (IllegalAccessException | IllegalArgumentException e) {
                log.error("Exception while getting value {} of class {}", field.getName(), target.getClass().getSimpleName(), e);
                throw asRuntimeException(e);
            }
        }

        /** Returns {@code e} itself if unchecked, otherwise wraps it in a {@link RuntimeException}. */
        private static RuntimeException asRuntimeException(Exception e) {
            return e instanceof RuntimeException ? (RuntimeException) e : new RuntimeException(e);
        }

        public static BucketOrder getReduceOrder(StringTerms stringTerms) {
            return getFieldValue(REDUCE_ORDER, stringTerms);
        }

        public static BytesRef getTerm(StringTerms.Bucket bucket) {
            return getFieldValue(TERM_BYTES, bucket);
        }

        public static DocValueFormat getDocValueFormat(StringTerms.Bucket bucket) {
            return getFieldValue(FORMAT, bucket);
        }
    }

    /**
     * Consumer that collapses each run of consecutive buckets comparing equal under
     * {@code comparator} into a single bucket. It must be fed buckets already sorted by that
     * comparator. Call {@link #getBuckets()} once all buckets have been accepted.
     */
    private static final class BucketMerger implements Consumer<StringTerms.Bucket> {
        private final Comparator<MultiBucketsAggregation.Bucket> comparator;
        private final ImmutableList.Builder<StringTerms.Bucket> builder;
        // Last bucket of the current run; null before the first bucket and after a run is flushed.
        private StringTerms.Bucket bucket;
        // Number of buckets in the current run and their summed doc counts / doc count errors.
        private int mergeCount;
        private long mergedDocCount;
        private long mergedDocCountError;
        // Cleared once any bucket in the run fails to report a doc count error.
        private boolean showDocCountError = true;

        BucketMerger(Comparator<MultiBucketsAggregation.Bucket> comparator, int size) {
            this.comparator = Objects.requireNonNull(comparator);
            this.builder = ImmutableList.builderWithExpectedSize(size);
        }

        /** Emits the current run: the bucket itself if the run has length 1, else a merged bucket. */
        private void finalizeBucket() {
            if (mergeCount == 1) {
                builder.add(bucket);
            } else {
                // NOTE(review): only the last bucket's sub-aggregations are kept; those of the other
                // merged buckets are dropped.
                builder.add(new StringTerms.Bucket(StringTermsGetter.getTerm(bucket), mergedDocCount,
                        (InternalAggregations) bucket.getAggregations(), showDocCountError, mergedDocCountError,
                        StringTermsGetter.getDocValueFormat(bucket)));
            }
        }

        /** Flushes the current run if {@code next} is {@code null} or does not belong to it. */
        private void merge(StringTerms.Bucket next) {
            if (bucket != null && (next == null || comparator.compare(bucket, next) != 0)) {
                finalizeBucket();
                bucket = null;
                mergeCount = 0;
                mergedDocCount = 0L;
                mergedDocCountError = 0L;
                showDocCountError = true;
            }
        }

        /** Flushes the final run and returns all merged buckets as an immutable list. */
        public List<StringTerms.Bucket> getBuckets() {
            merge(null);
            return builder.build();
        }

        @Override
        public void accept(StringTerms.Bucket next) {
            merge(next);
            mergeCount++;
            mergedDocCount += next.getDocCount();
            if (showDocCountError) {
                try {
                    mergedDocCountError += next.getDocCountError();
                } catch (IllegalStateException e) {
                    // presumably thrown when the bucket does not track a doc count error — confirm
                    showDocCountError = false;
                }
            }
            bucket = next;
        }
    }
}

