/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification.document;

import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.KNearestNeighborClassifier;
import org.apache.lucene.classification.document.DocumentClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.util.BytesRef;

/**
 * A k-nearest-neighbour {@link DocumentClassifier} that classifies whole {@link Document}s.
 *
 * <p>For each configured text field, every value of that field in the input document is turned
 * into a "more like this" query through the inherited {@code mlt} helper, using the analyzer
 * registered for that field. The top {@code k} hits that also have a value in the class field are
 * then handed to the superclass helpers to produce classification results.
 *
 * <p>NOTE(review): {@code knnSearch} mutates the inherited {@code mlt} helper (boost, boost factor,
 * analyzer). From the code visible here, instances do not appear safe for concurrent use; confirm
 * before sharing across threads.
 */
public class KNearestNeighborDocumentClassifier
extends KNearestNeighborClassifier
implements DocumentClassifier<BytesRef> {
    /** Analyzer to use for each text field, keyed by bare field name (without any {@code ^boost} suffix). */
    protected final Map<String, Analyzer> field2analyzer;

    /**
     * Creates a document classifier.
     *
     * @param indexReader index reader handed to the superclass
     * @param similarity similarity handed to the superclass
     * @param query optional extra constraint; when non-null it is added as a MUST clause of every kNN query
     * @param k maximum number of neighbours retrieved per classification
     * @param minDocsFreq handed to the superclass
     * @param minTermFreq handed to the superclass
     * @param classFieldName field holding the class label; only documents with a value in it are candidates
     * @param field2analyzer analyzer per text field, used to tokenize the input document's values
     * @param textFieldNames text fields to compare on; each may carry a {@code ^boost} suffix, e.g. {@code "title^2"}
     * @throws IOException if the superclass constructor fails
     */
    public KNearestNeighborDocumentClassifier(IndexReader indexReader, Similarity similarity, Query query, int k, int minDocsFreq, int minTermFreq, String classFieldName, Map<String, Analyzer> field2analyzer, String ... textFieldNames) throws IOException {
        // The superclass's single-analyzer slot is left null: analyzers are chosen per field
        // from field2analyzer at search time instead.
        super(indexReader, similarity, null, query, k, minDocsFreq, minTermFreq, classFieldName, textFieldNames);
        this.field2analyzer = field2analyzer;
    }

    /**
     * Assigns a class to {@code document} based on its k nearest neighbours.
     *
     * @param document the document to classify
     * @return the result chosen by the superclass from the neighbours' classes
     * @throws IOException if searching the index fails
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
        return this.classifyFromTopDocs(this.knnSearch(document));
    }

    /**
     * Returns all candidate classes for {@code document}, sorted by
     * {@link ClassificationResult}'s natural ordering.
     *
     * @param document the document to classify
     * @return the sorted candidate classes
     * @throws IOException if searching the index fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
        return this.sortedClasses(document);
    }

    /**
     * Returns at most {@code max} candidate classes for {@code document}, sorted by
     * {@link ClassificationResult}'s natural ordering.
     *
     * @param document the document to classify
     * @param max maximum number of results; a value {@code <= 0} yields an empty list
     * @return a {@link List#subList} view over the first {@code min(max, size)} sorted results
     * @throws IOException if searching the index fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = this.sortedClasses(document);
        // Clamp at 0 so a negative max returns an empty list instead of making subList throw.
        int limit = Math.max(0, Math.min(max, assignedClasses.size()));
        return assignedClasses.subList(0, limit);
    }

    /** Runs the kNN search for {@code document} and returns the resulting classes, sorted. */
    private List<ClassificationResult<BytesRef>> sortedClasses(Document document) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = this.buildListFromTopDocs(this.knnSearch(document));
        Collections.sort(assignedClasses);
        return assignedClasses;
    }

    /**
     * Builds and runs the kNN query for {@code document}.
     *
     * <p>Each value of each configured text field becomes a SHOULD clause. A MUST clause requires
     * the class field to have a value, and another MUST clause applies the optional extra
     * {@code query}.
     *
     * @param document the document whose field values drive the search
     * @return the top {@code k} hits
     * @throws IOException if building the queries or searching fails
     * @throws IllegalArgumentException if a text field spec contains {@code ^} but no boost value
     */
    private TopDocs knnSearch(Document document) throws IOException {
        BooleanQuery.Builder mltQuery = new BooleanQuery.Builder();
        for (String fieldSpec : this.textFieldNames) {
            String fieldName = fieldSpec;
            String boost = null;
            if (fieldSpec.contains("^")) {
                String[] field2boost = fieldSpec.split("\\^");
                if (field2boost.length < 2) {
                    // e.g. "title^": split drops the trailing empty token, so no boost is present
                    throw new IllegalArgumentException("missing boost after '^' in text field name: " + fieldSpec);
                }
                fieldName = field2boost[0];
                boost = field2boost[1];
            }
            String[] fieldValues = document.getValues(fieldName);
            this.mlt.setBoost(true);
            try {
                if (boost != null) {
                    // Extra multiplicative boost that applies to this field only.
                    this.mlt.setBoostFactor(Float.parseFloat(boost));
                }
                this.mlt.setAnalyzer(this.field2analyzer.get(fieldName));
                for (String fieldContent : fieldValues) {
                    mltQuery.add(new BooleanClause(this.mlt.like(fieldName, new Reader[]{new StringReader(fieldContent)}), BooleanClause.Occur.SHOULD));
                }
            } finally {
                // mlt is shared instance state: always restore the neutral boost factor, even if
                // query construction failed, so the field boost cannot leak into later calls.
                this.mlt.setBoostFactor(1.0f);
            }
        }
        // Only documents that carry a class label can act as neighbours.
        Query classFieldQuery = new WildcardQuery(new Term(this.classFieldName, "*"));
        mltQuery.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
        if (this.query != null) {
            mltQuery.add(this.query, BooleanClause.Occur.MUST);
        }
        return this.indexSearcher.search(mltQuery.build(), this.k);
    }
}

