/*
 * Copyright (c) Atmosphère-NX
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms and conditions of the GNU General Public License,
 * version 2, as published by the Free Software Foundation.
 *
 * This program is distributed in the hope it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
 * more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
#pragma once
#include <stratosphere/fs/common/fs_dbm_rom_types.hpp>
#include <stratosphere/fs/fs_substorage.hpp>

namespace ams::fs {

    /* ACCURATE_TO_VERSION: 14.3.0.0 */
    /*
     * A chained hash table laid out across two storages:
     *   - the bucket storage holds one Position per bucket, naming the head of that
     *     bucket's chain (InvalidPosition when the bucket is empty);
     *   - the key-value storage holds packed Element records, each immediately followed
     *     by Element::size bytes of auxiliary data that Key::IsEqual consults on lookup.
     * New entries are appended at the end of the used key-value region and linked in at
     * the head of their bucket's chain.
     */
    template<typename KeyType, typename ValueType, size_t MaxAuxiliarySize>
    class KeyValueRomStorageTemplate {
        public:
            using Key         = KeyType;
            using Value       = ValueType;
            using Position    = u32;
            using BucketIndex = s64;

            using StorageSizeType = u32;

            /* Pairs a bucket index with a key-value entry position. */
            struct FindIndex {
                BucketIndex ind;
                Position pos;
            };
            static_assert(util::is_pod<FindIndex>::value);
        private:
            /* Sentinel marking an empty bucket or the end of a chain. */
            static constexpr inline Position InvalidPosition = ~Position();

            /* Fixed-size record header; Element::size bytes of auxiliary data follow it in storage. */
            struct Element {
                Key key;
                Value value;
                Position next;        /* Next element in the same bucket chain, or InvalidPosition. */
                StorageSizeType size; /* Length in bytes of the trailing auxiliary data. */
            };
            static_assert(util::is_pod<Element>::value);
        private:
            s64 m_bucket_count;
            SubStorage m_bucket_storage;
            SubStorage m_kv_storage;
            s64 m_total_entry_size; /* Bytes of key-value storage handed out so far; the next entry's offset. */
            u32 m_entry_count;      /* Number of entries successfully added through AddInternal. */
        public:
            /* Returns the bucket storage size, in bytes, required for num buckets. */
            static constexpr s64 QueryBucketStorageSize(s64 num) {
                return num * sizeof(Position);
            }

            /* Returns how many buckets fit in a bucket storage of size bytes. */
            static constexpr s64 QueryBucketCount(StorageSizeType size) {
                return size / sizeof(Position);
            }

            /*
             * Returns the key-value storage bytes for one entry carrying aux_size bytes of
             * auxiliary data, rounded up to alignof(Element). AllocateEntry only aligns to
             * alignof(Position), so this is never smaller than what an entry actually consumes.
             */
            static constexpr size_t QueryEntrySize(StorageSizeType aux_size) {
                return util::AlignUp<size_t>(sizeof(Element) + aux_size, alignof(Element));
            }

            /* Marks the first count buckets of bucket as empty by writing InvalidPosition to each. */
            static Result Format(SubStorage bucket, StorageSizeType count) {
                const Position pos = InvalidPosition;
                for (auto i = 0u; i < count; i++) {
                    R_TRY(bucket.Write(i * sizeof(pos), std::addressof(pos), sizeof(pos)));
                }
                R_SUCCEED();
            }
        public:
            constexpr KeyValueRomStorageTemplate() : m_bucket_count(), m_bucket_storage(), m_kv_storage(), m_total_entry_size(), m_entry_count() { /* ... */ }

            /* Binds the table to its bucket storage (count buckets) and key-value storage. */
            Result Initialize(const SubStorage &bucket, s64 count, const SubStorage &kv) {
                AMS_ASSERT(count > 0);
                m_bucket_storage = bucket;
                m_bucket_count   = count;
                m_kv_storage     = kv;
                R_SUCCEED();
            }

            /* Releases the storages and returns the object to its default-constructed state. */
            void Finalize() {
                m_bucket_storage   = SubStorage();
                m_bucket_count     = 0;
                m_kv_storage       = SubStorage();
                /* Reset allocation bookkeeping too, so a later Initialize() on different */
                /* storage does not resume allocating at a stale offset. */
                m_total_entry_size = 0;
                m_entry_count      = 0;
            }

            /* Returns the number of key-value storage bytes consumed by entries added so far. */
            s64 GetTotalEntrySize() const {
                return m_total_entry_size;
            }
        protected:
            /*
             * Inserts a new entry for key (hashed as hash_key) with the given auxiliary data and value.
             * On success *out receives the entry's position in the key-value storage.
             * Returns ResultDbmAlreadyExists if an equal key is already present, propagates any other
             * lookup failure, and returns ResultDbmKeyFull when the key-value storage has no room.
             */
            Result AddInternal(Position *out, const Key &key, u32 hash_key, const void *aux, size_t aux_size, const Value &value) {
                AMS_ASSERT(out != nullptr);
                AMS_ASSERT(aux != nullptr || aux_size == 0);
                AMS_ASSERT(m_bucket_count > 0);
                /* FindInternal reads auxiliary data into a MaxAuxiliarySize buffer, so larger entries */
                /* could never be looked up safely. */
                AMS_ASSERT(aux_size <= MaxAuxiliarySize);

                {
                    Position pos, prev_pos;
                    Element elem;

                    const Result find_res = this->FindInternal(std::addressof(pos), std::addressof(prev_pos), std::addressof(elem), key, hash_key, aux, aux_size);
                    R_UNLESS(R_FAILED(find_res),                           fs::ResultDbmAlreadyExists());
                    R_UNLESS(fs::ResultDbmKeyNotFound::Includes(find_res), find_res);
                }

                Position pos;
                R_TRY(this->AllocateEntry(std::addressof(pos), static_cast<StorageSizeType>(aux_size)));

                /* The bucket head is updated before the record is written. */
                /* NOTE(review): if WriteKeyValue fails, the bucket points at an unwritten record. */
                Position next_pos;
                R_TRY(this->LinkEntry(std::addressof(next_pos), pos, hash_key));

                const Element elem = { key, value, next_pos, static_cast<StorageSizeType>(aux_size) };
                R_TRY(this->WriteKeyValue(std::addressof(elem), pos, aux, aux_size));

                *out = pos;
                m_entry_count++;

                R_SUCCEED();
            }

            /* Looks up key; on success stores the entry position in *out_pos and its value in *out_val. */
            Result GetInternal(Position *out_pos, Value *out_val, const Key &key, u32 hash_key, const void *aux, size_t aux_size) {
                AMS_ASSERT(out_pos != nullptr);
                AMS_ASSERT(out_val != nullptr);
                /* Same contract as AddInternal/FindInternal: aux may be null only when it is empty. */
                AMS_ASSERT(aux != nullptr || aux_size == 0);

                Position pos, prev_pos;
                Element elem;
                R_TRY(this->FindInternal(std::addressof(pos), std::addressof(prev_pos), std::addressof(elem), key, hash_key, aux, aux_size));

                *out_pos = pos;
                *out_val = elem.value;
                R_SUCCEED();
            }

            /* Reads the key and value of the entry at pos (auxiliary data is not read). */
            Result GetByPosition(Key *out_key, Value *out_val, Position pos) {
                AMS_ASSERT(out_key != nullptr);
                AMS_ASSERT(out_val != nullptr);

                Element elem;
                R_TRY(this->ReadKeyValue(std::addressof(elem), pos));

                *out_key = elem.key;
                *out_val = elem.value;
                R_SUCCEED();
            }

            /*
             * Reads the key, value and auxiliary data of the entry at pos.
             * out_aux is assumed to have room for MaxAuxiliarySize bytes; *out_aux_size receives
             * the number of auxiliary bytes written.
             */
            Result GetByPosition(Key *out_key, Value *out_val, void *out_aux, size_t *out_aux_size, Position pos) {
                AMS_ASSERT(out_key != nullptr);
                AMS_ASSERT(out_val != nullptr);
                AMS_ASSERT(out_aux != nullptr);
                AMS_ASSERT(out_aux_size != nullptr);

                Element elem;
                R_TRY(this->ReadKeyValue(std::addressof(elem), out_aux, out_aux_size, pos));

                *out_key = elem.key;
                *out_val = elem.value;
                R_SUCCEED();
            }

            /* Overwrites only the value of the entry at pos; key, chain link and auxiliary data are kept. */
            Result SetByPosition(Position pos, const Value &value) {
                Element elem;
                R_TRY(this->ReadKeyValue(std::addressof(elem), pos));
                elem.value = value;
                R_RETURN(this->WriteKeyValue(std::addressof(elem), pos, nullptr, 0));
            }
        private:
            /* Maps a hash to its bucket index in [0, m_bucket_count). */
            BucketIndex HashToBucket(u32 hash_key) const {
                return hash_key % m_bucket_count;
            }

            /*
             * Walks the chain of hash_key's bucket looking for an entry equal to key/aux.
             * *out_pos and *out_prev start at 0. On success *out_pos is the matching entry,
             * *out_prev its chain predecessor (left 0 when the match is the bucket head), and
             * *out_elem its header. Returns ResultDbmKeyNotFound if the bucket is empty or the
             * chain ends without a match.
             * NOTE(review): a cyclic chain in corrupt storage would make this loop never terminate.
             */
            Result FindInternal(Position *out_pos, Position *out_prev, Element *out_elem, const Key &key, u32 hash_key, const void *aux, size_t aux_size) {
                AMS_ASSERT(out_pos != nullptr);
                AMS_ASSERT(out_prev != nullptr);
                AMS_ASSERT(out_elem != nullptr);
                AMS_ASSERT(aux != nullptr || aux_size == 0);
                AMS_ASSERT(m_bucket_count > 0);

                *out_pos = 0;
                *out_prev = 0;

                const BucketIndex ind = HashToBucket(hash_key);

                Position cur;
                R_TRY(this->ReadBucket(std::addressof(cur), ind));

                s64 kv_size;
                R_TRY(m_kv_storage.GetSize(std::addressof(kv_size)));
                AMS_ASSERT(cur == InvalidPosition || cur < kv_size);

                R_UNLESS(cur != InvalidPosition, fs::ResultDbmKeyNotFound());

                /* Scratch buffer for each candidate's auxiliary data. */
                auto buf = ::ams::fs::impl::MakeUnique<u8[]>(MaxAuxiliarySize);
                R_UNLESS(buf != nullptr, fs::ResultAllocationMemoryFailedMakeUnique());

                while (true) {
                    size_t cur_aux_size;
                    R_TRY(this->ReadKeyValue(out_elem, buf.get(), std::addressof(cur_aux_size), cur));

                    if (key.IsEqual(out_elem->key, aux, aux_size, buf.get(), cur_aux_size)) {
                        *out_pos = cur;
                        R_SUCCEED();
                    }

                    *out_prev = cur;
                    cur = out_elem->next;
                    R_UNLESS(cur != InvalidPosition, fs::ResultDbmKeyNotFound());
                }
            }

            /*
             * Reserves space for an Element plus aux_size bytes at the end of the used key-value region.
             * *out receives the entry's offset. Returns ResultDbmKeyFull if the entry would not fit.
             */
            Result AllocateEntry(Position *out, StorageSizeType aux_size) {
                AMS_ASSERT(out != nullptr);

                s64 kv_size;
                R_TRY(m_kv_storage.GetSize(std::addressof(kv_size)));
                const size_t end_pos = m_total_entry_size + sizeof(Element) + static_cast<size_t>(aux_size);
                R_UNLESS(end_pos <= static_cast<size_t>(kv_size), fs::ResultDbmKeyFull());

                *out = static_cast<Position>(m_total_entry_size);

                /* Keep each subsequent entry's start aligned to alignof(Position). */
                m_total_entry_size = util::AlignUp<s64>(static_cast<s64>(end_pos), alignof(Position));
                R_SUCCEED();
            }

            /* Makes pos the new head of hash_key's bucket; *out receives the previous head. */
            Result LinkEntry(Position *out, Position pos, u32 hash_key) {
                AMS_ASSERT(out != nullptr);

                const BucketIndex ind = HashToBucket(hash_key);

                Position next;
                R_TRY(this->ReadBucket(std::addressof(next), ind));

                s64 kv_size;
                R_TRY(m_kv_storage.GetSize(std::addressof(kv_size)));
                AMS_ASSERT(next == InvalidPosition || next < kv_size);

                R_TRY(this->WriteBucket(pos, ind));

                *out = next;
                R_SUCCEED();
            }

            /* Reads the chain-head position stored for bucket ind. */
            Result ReadBucket(Position *out, BucketIndex ind) {
                AMS_ASSERT(out != nullptr);
                AMS_ASSERT(ind < m_bucket_count);

                const s64 offset = ind * sizeof(Position);
                R_RETURN(m_bucket_storage.Read(offset, out, sizeof(*out)));
            }

            /* Stores pos as the chain head for bucket ind. */
            Result WriteBucket(Position pos, BucketIndex ind) {
                AMS_ASSERT(ind < m_bucket_count);

                const s64 offset = ind * sizeof(Position);
                R_RETURN(m_bucket_storage.Write(offset, std::addressof(pos), sizeof(pos)));
            }

            /* Reads only the Element header at pos. */
            Result ReadKeyValue(Element *out, Position pos) {
                AMS_ASSERT(out != nullptr);

                s64 kv_size;
                R_TRY(m_kv_storage.GetSize(std::addressof(kv_size)));
                AMS_ASSERT(pos < kv_size);

                R_RETURN(m_kv_storage.Read(pos, out, sizeof(*out)));
            }

            /*
             * Reads the Element header at pos and its trailing auxiliary data into out_aux,
             * which is assumed to hold MaxAuxiliarySize bytes.
             */
            Result ReadKeyValue(Element *out, void *out_aux, size_t *out_aux_size, Position pos) {
                AMS_ASSERT(out != nullptr);
                AMS_ASSERT(out_aux != nullptr);
                AMS_ASSERT(out_aux_size != nullptr);

                R_TRY(this->ReadKeyValue(out, pos));

                /* The stored size comes from storage; reject anything that would overrun the caller's buffer. */
                /* NOTE(review): this is debug-only; release builds trust the on-storage size field. */
                AMS_ASSERT(out->size <= MaxAuxiliarySize);

                *out_aux_size = out->size;
                if (out->size > 0) {
                    R_TRY(m_kv_storage.Read(pos + sizeof(*out), out_aux, out->size));
                }

                R_SUCCEED();
            }

            /* Writes the Element header at pos and, when non-empty, aux_size bytes of aux right after it. */
            Result WriteKeyValue(const Element *elem, Position pos, const void *aux, size_t aux_size) {
                AMS_ASSERT(elem != nullptr);
                /* SetByPosition passes (nullptr, 0) to rewrite only the header, so a null aux is valid when empty. */
                AMS_ASSERT(aux != nullptr || aux_size == 0);

                s64 kv_size;
                R_TRY(m_kv_storage.GetSize(std::addressof(kv_size)));
                AMS_ASSERT(pos < kv_size);

                R_TRY(m_kv_storage.Write(pos, elem, sizeof(*elem)));

                if (aux != nullptr && aux_size > 0) {
                    R_TRY(m_kv_storage.Write(pos + sizeof(*elem), aux, aux_size));
                }

                R_SUCCEED();
            }
    };

}
