/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.azure;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.IntArrayList;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.Content;
import dev.langchain4j.data.message.ImageContent;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.azure.AzureOpenAiChatModelName;
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModelName;
import dev.langchain4j.model.azure.AzureOpenAiLanguageModelName;
import dev.langchain4j.model.chat.request.json.JsonArraySchema;
import dev.langchain4j.model.chat.request.json.JsonBooleanSchema;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.request.json.JsonIntegerSchema;
import dev.langchain4j.model.chat.request.json.JsonNumberSchema;
import dev.langchain4j.model.chat.request.json.JsonObjectSchema;
import dev.langchain4j.model.chat.request.json.JsonSchemaElement;
import dev.langchain4j.model.chat.request.json.JsonStringSchema;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

/**
 * A {@link Tokenizer} for Azure OpenAI models backed by the jtokkit encoding that jtokkit
 * associates with the configured model name.
 * <p>
 * Message and tool counts are estimates. On top of the raw text token counts they add fixed,
 * model-dependent offsets. Those offsets are heuristic constants whose derivation is not
 * documented here.
 * <p>
 * Construction succeeds even when jtokkit does not know the model. However, any method that
 * needs the encoding then throws an {@link IllegalArgumentException}.
 */
public class AzureOpenAiTokenizer
implements Tokenizer {
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final String modelName;
    private final Optional<Encoding> encoding;

    /** Creates a tokenizer using the model type of {@link AzureOpenAiChatModelName#GPT_3_5_TURBO}. */
    public AzureOpenAiTokenizer() {
        this(AzureOpenAiChatModelName.GPT_3_5_TURBO.modelType());
    }

    /** Creates a tokenizer using the model type of the given chat model. */
    public AzureOpenAiTokenizer(AzureOpenAiChatModelName modelName) {
        this(modelName.modelType());
    }

    /** Creates a tokenizer using the model type of the given embedding model. */
    public AzureOpenAiTokenizer(AzureOpenAiEmbeddingModelName modelName) {
        this(modelName.modelType());
    }

    /** Creates a tokenizer using the model type of the given language model. */
    public AzureOpenAiTokenizer(AzureOpenAiLanguageModelName modelName) {
        this(modelName.modelType());
    }

    /**
     * Creates a tokenizer for the given model name.
     *
     * @param modelName the name used both to look up the jtokkit encoding and to select
     *                  model-specific overheads; validated with {@code ValidationUtils.ensureNotBlank}
     */
    public AzureOpenAiTokenizer(String modelName) {
        this.modelName = ValidationUtils.ensureNotBlank(modelName, "modelName");
        this.encoding = Encodings.newLazyEncodingRegistry().getEncodingForModel(modelName);
    }

    /**
     * Counts the tokens of {@code text} using the model's encoding, without special tokens.
     *
     * @throws IllegalArgumentException if jtokkit has no encoding for the model
     */
    public int estimateTokenCountInText(String text) {
        return this.encodingOrThrow().countTokensOrdinary(text);
    }

    /**
     * Estimates the tokens a single chat message consumes, including per-message overhead.
     *
     * @throws IllegalArgumentException if the message type is not supported
     */
    public int estimateTokenCountInMessage(ChatMessage message) {
        // One base token plus the model-specific per-message overhead.
        int tokenCount = 1 + this.extraTokensPerMessage();
        if (message instanceof SystemMessage) {
            tokenCount += this.estimateTokenCountIn((SystemMessage) message);
        } else if (message instanceof UserMessage) {
            tokenCount += this.estimateTokenCountIn((UserMessage) message);
        } else if (message instanceof AiMessage) {
            tokenCount += this.estimateTokenCountIn((AiMessage) message);
        } else if (message instanceof ToolExecutionResultMessage) {
            tokenCount += this.estimateTokenCountIn((ToolExecutionResultMessage) message);
        } else {
            throw new IllegalArgumentException("Unknown message type: " + message);
        }
        return tokenCount;
    }

    private int estimateTokenCountIn(SystemMessage systemMessage) {
        return this.estimateTokenCountInText(systemMessage.text());
    }

    /** Counts text parts, a fixed estimate per image, and (for most models) the author name. */
    private int estimateTokenCountIn(UserMessage userMessage) {
        int tokenCount = 0;
        for (Content content : userMessage.contents()) {
            if (content instanceof TextContent) {
                tokenCount += this.estimateTokenCountInText(((TextContent) content).text());
            } else if (content instanceof ImageContent) {
                // Fixed per-image estimate; the image itself is not inspected.
                tokenCount += 85;
            } else {
                // Pass the content as an argument, not as part of the format string, so a '%'
                // in its toString() cannot break message formatting.
                throw Exceptions.illegalArgument("Unknown content type: %s", content);
            }
        }
        if (userMessage.name() != null
                && !this.modelName.equals(AzureOpenAiChatModelName.GPT_4_VISION_PREVIEW.toString())) {
            tokenCount += this.extraTokensPerName();
            tokenCount += this.estimateTokenCountInText(userMessage.name());
        }
        return tokenCount;
    }

    /** Counts the AI message text plus any tool execution requests it carries. */
    private int estimateTokenCountIn(AiMessage aiMessage) {
        int tokenCount = 0;
        if (aiMessage.text() != null) {
            tokenCount += this.estimateTokenCountInText(aiMessage.text());
        }
        List<ToolExecutionRequest> toolExecutionRequests = aiMessage.toolExecutionRequests();
        // An empty list means "no tool calls" and must not be charged tool-call overhead.
        if (toolExecutionRequests == null || toolExecutionRequests.isEmpty()) {
            return tokenCount;
        }
        tokenCount += this.isOneOfLatestModels() ? 6 : 3;
        if (toolExecutionRequests.size() == 1) {
            ToolExecutionRequest toolExecutionRequest = toolExecutionRequests.get(0);
            tokenCount -= 1;
            tokenCount += this.estimateTokenCountInText(toolExecutionRequest.name()) * 2;
            tokenCount += this.estimateTokenCountInNullableText(toolExecutionRequest.arguments());
        } else {
            tokenCount += 15;
            for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
                tokenCount += 7;
                tokenCount += this.estimateTokenCountInText(toolExecutionRequest.name());
                Map<String, Object> arguments = parseArguments(toolExecutionRequest.arguments());
                for (Map.Entry<String, Object> argument : arguments.entrySet()) {
                    tokenCount += 2;
                    tokenCount += this.estimateTokenCountInText(argument.getKey());
                    // String.valueOf: JSON null values must not cause a NullPointerException.
                    tokenCount += this.estimateTokenCountInText(String.valueOf(argument.getValue()));
                }
            }
        }
        return tokenCount;
    }

    private int estimateTokenCountIn(ToolExecutionResultMessage toolExecutionResultMessage) {
        return this.estimateTokenCountInText(toolExecutionResultMessage.text());
    }

    // NOTE(review): the enum-based constructors store modelType(), while the model checks in this
    // class compare against the enum's modelName() or toString(). Confirm these values coincide
    // for the models that are meant to match.
    private int extraTokensPerMessage() {
        if (this.modelName.equals(AzureOpenAiChatModelName.GPT_3_5_TURBO_0301.modelName())) {
            return 4;
        }
        return 3;
    }

    private int extraTokensPerName() {
        if (this.modelName.equals(AzureOpenAiChatModelName.GPT_3_5_TURBO_0301.toString())) {
            return -1;
        }
        return 1;
    }

    /** Estimates the tokens of a whole conversation: a fixed 3-token base plus each message. */
    public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
        int tokenCount = 3;
        for (ChatMessage message : messages) {
            tokenCount += this.estimateTokenCountInMessage(message);
        }
        return tokenCount;
    }

    /** Estimates the tokens consumed by declaring the given tools to the model. */
    public int estimateTokenCountInToolSpecifications(Iterable<ToolSpecification> toolSpecifications) {
        int tokenCount = 16;
        for (ToolSpecification toolSpecification : toolSpecifications) {
            tokenCount += 6;
            tokenCount += this.estimateTokenCountInText(toolSpecification.name());
            if (toolSpecification.description() != null) {
                tokenCount += 2;
                tokenCount += this.estimateTokenCountInText(toolSpecification.description());
            }
            tokenCount += this.estimateTokenCountInToolParameters(toolSpecification.parameters());
        }
        return tokenCount;
    }

    /** Estimates the tokens of a tool's parameter schema; {@code null} parameters cost nothing. */
    private int estimateTokenCountInToolParameters(JsonObjectSchema parameters) {
        if (parameters == null) {
            return 0;
        }
        boolean latest = this.isOneOfLatestModels();
        Map<String, JsonSchemaElement> properties = parameters.properties();
        int tokenCount = 3;
        if (latest) {
            tokenCount += properties.size() - 1;
        }
        for (Map.Entry<String, JsonSchemaElement> entry : properties.entrySet()) {
            String property = entry.getKey();
            JsonSchemaElement element = entry.getValue();
            tokenCount += latest ? 2 : 3;
            tokenCount += this.estimateTokenCountInText(property);
            if (element instanceof JsonArraySchema) {
                if (latest) {
                    tokenCount += 1;
                }
            } else if (element instanceof JsonBooleanSchema || element instanceof JsonIntegerSchema
                    || element instanceof JsonNumberSchema || element instanceof JsonStringSchema) {
                tokenCount += 2;
                tokenCount += this.estimateTokenCountInNullableText(descriptionOf(element));
                // required() may be absent; treat that as "no required properties".
                if (latest && parameters.required() != null && parameters.required().contains(property)) {
                    tokenCount += 1;
                }
            } else if (element instanceof JsonEnumSchema) {
                tokenCount -= latest ? 2 : 3;
                for (String value : ((JsonEnumSchema) element).enumValues()) {
                    tokenCount += 3;
                    tokenCount += this.estimateTokenCountInText(value);
                }
            }
            // Other schema element types contribute only their name and separator tokens.
        }
        return tokenCount;
    }

    /** Returns the description of a boolean, integer, number or string schema element. */
    private static String descriptionOf(JsonSchemaElement element) {
        if (element instanceof JsonBooleanSchema) {
            return ((JsonBooleanSchema) element).description();
        }
        if (element instanceof JsonIntegerSchema) {
            return ((JsonIntegerSchema) element).description();
        }
        if (element instanceof JsonNumberSchema) {
            return ((JsonNumberSchema) element).description();
        }
        return ((JsonStringSchema) element).description();
    }

    /** Estimates the tokens of a tool declaration when the model is forced to call that tool. */
    public int estimateTokenCountInForcefulToolSpecification(ToolSpecification toolSpecification) {
        int tokenCount = this.estimateTokenCountInToolSpecifications(Collections.singletonList(toolSpecification));
        tokenCount += 4;
        tokenCount += this.estimateTokenCountInText(toolSpecification.name());
        if (this.isOneOfLatestModels()) {
            tokenCount += 3;
        }
        return tokenCount;
    }

    /** Encodes {@code text} into token ids, without special tokens. */
    public List<Integer> encode(String text) {
        return this.encodingOrThrow().encodeOrdinary(text).boxed();
    }

    /** Encodes at most {@code maxTokensToEncode} tokens of {@code text}, without special tokens. */
    public List<Integer> encode(String text, int maxTokensToEncode) {
        return this.encodingOrThrow().encodeOrdinary(text, maxTokensToEncode).getTokens().boxed();
    }

    /** Decodes token ids back into text. */
    public String decode(List<Integer> tokens) {
        IntArrayList intArrayList = new IntArrayList();
        for (Integer token : tokens) {
            intArrayList.add(token.intValue());
        }
        return this.encodingOrThrow().decode(intArrayList);
    }

    /** Returns the model's encoding, or throws if jtokkit does not know the model. */
    private Encoding encodingOrThrow() {
        return this.encoding.orElseThrow(this.unknownModelException());
    }

    private Supplier<IllegalArgumentException> unknownModelException() {
        return () -> Exceptions.illegalArgument("Model '%s' is unknown to jtokkit", this.modelName);
    }

    /** Like {@link #estimateTokenCountInText(String)}, but counts {@code null} as zero tokens. */
    private int estimateTokenCountInNullableText(String text) {
        return text == null ? 0 : this.estimateTokenCountInText(text);
    }

    /** Estimates the tokens of tool execution requests emitted by the model. */
    public int estimateTokenCountInToolExecutionRequests(Iterable<ToolExecutionRequest> toolExecutionRequests) {
        int tokenCount = 0;
        int toolsCount = 0;
        int toolsWithArgumentsCount = 0;
        int toolsWithoutArgumentsCount = 0;
        int totalArgumentsCount = 0;
        for (ToolExecutionRequest toolExecutionRequest : toolExecutionRequests) {
            tokenCount += 4;
            tokenCount += this.estimateTokenCountInText(toolExecutionRequest.name());
            tokenCount += this.estimateTokenCountInNullableText(toolExecutionRequest.arguments());
            int argumentCount = countArguments(toolExecutionRequest.arguments());
            if (argumentCount == 0) {
                toolsWithoutArgumentsCount++;
            } else {
                toolsWithArgumentsCount++;
            }
            totalArgumentsCount += argumentCount;
            toolsCount++;
        }
        if (this.modelName.equals(AzureOpenAiChatModelName.GPT_3_5_TURBO_1106.toString()) || this.isOneOfLatestGpt4Models()) {
            tokenCount += 16;
            tokenCount += 3 * toolsWithoutArgumentsCount;
            tokenCount += toolsCount;
            if (totalArgumentsCount > 0) {
                tokenCount -= 1;
                tokenCount -= 2 * totalArgumentsCount;
                tokenCount += 2 * toolsWithArgumentsCount;
                tokenCount += toolsCount;
            }
        }
        if (this.modelName.equals(AzureOpenAiChatModelName.GPT_4_1106_PREVIEW.toString())) {
            tokenCount += 3;
            if (toolsCount > 1) {
                tokenCount += 18;
                tokenCount += 15 * toolsCount;
                tokenCount += totalArgumentsCount;
                tokenCount -= 3 * toolsWithoutArgumentsCount;
            }
        }
        return tokenCount;
    }

    /** Estimates the tokens of the tool execution request produced when a tool call is forced. */
    public int estimateTokenCountInForcefulToolExecutionRequest(ToolExecutionRequest toolExecutionRequest) {
        if (this.isOneOfLatestGpt4Models()) {
            int argumentsCount = countArguments(toolExecutionRequest.arguments());
            if (argumentsCount == 0) {
                return 1;
            }
            return this.estimateTokenCountInText(toolExecutionRequest.arguments());
        }
        int tokenCount = this.estimateTokenCountInToolExecutionRequests(Collections.singletonList(toolExecutionRequest));
        tokenCount -= 4;
        tokenCount -= this.estimateTokenCountInText(toolExecutionRequest.name());
        if (this.modelName.equals(AzureOpenAiChatModelName.GPT_3_5_TURBO_1106.toString())) {
            int argumentsCount = countArguments(toolExecutionRequest.arguments());
            if (argumentsCount == 0) {
                return 1;
            }
            tokenCount -= 19;
            tokenCount += 2 * argumentsCount;
        }
        return tokenCount;
    }

    /**
     * Returns the number of top-level entries in a JSON-object arguments string.
     * A {@code null} or blank string counts as zero arguments.
     *
     * @throws RuntimeException wrapping the Jackson error if the string is not a JSON object
     */
    static int countArguments(String arguments) {
        return parseArguments(arguments).size();
    }

    /**
     * Parses a tool-call arguments string into a map. A {@code null} or blank string yields an
     * empty map, so callers share the same "no arguments" convention.
     */
    @SuppressWarnings("unchecked") // Jackson builds a Map<String, Object> when the target is Map.class
    private static Map<String, Object> parseArguments(String arguments) {
        if (Utils.isNullOrBlank(arguments)) {
            return Collections.emptyMap();
        }
        try {
            return OBJECT_MAPPER.readValue(arguments, Map.class);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    private boolean isOneOfLatestModels() {
        return this.isOneOfLatestGpt3Models() || this.isOneOfLatestGpt4Models();
    }

    private boolean isOneOfLatestGpt3Models() {
        return this.modelName.equals(AzureOpenAiChatModelName.GPT_3_5_TURBO_1106.toString())
                || this.modelName.equals(AzureOpenAiChatModelName.GPT_3_5_TURBO.toString());
    }

    private boolean isOneOfLatestGpt4Models() {
        return this.modelName.equals(AzureOpenAiChatModelName.GPT_4_TURBO.toString())
                || this.modelName.equals(AzureOpenAiChatModelName.GPT_4_1106_PREVIEW.toString())
                || this.modelName.equals(AzureOpenAiChatModelName.GPT_4_0125_PREVIEW.toString());
    }
}

