/*
 * Copyright (C) 2025 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/protovm/test/sample_packets.h"
#include "test/gtest_and_gmock.h"

#include "src/protovm/test/protos/incremental_trace.pb.h"
#include "src/protovm/test/sample_programs.h"
#include "src/protovm/test/utils.h"
#include "src/protovm/vm.h"

namespace perfetto {
namespace protovm {
namespace test {

class VmTest : public ::testing::Test {
 protected:
  static constexpr size_t MEMORY_LIMIT_BYTES = 10 * 1024 * 1024;

  std::string InitialIncrementalState() const {
    protos::TraceEntry state{};
    auto* element0 = state.add_elements();
    element0->set_id(0);
    element0->set_value(10);
    auto* element1 = state.add_elements();
    element1->set_id(1);
    element1->set_value(11);
    return state.SerializeAsString();
  }

  std::string SerializeIncrementalStateAsString(const Vm& vm) const {
    protozero::HeapBuffered<protozero::Message> proto;
    vm.SerializeIncrementalState(proto.get());
    return proto.SerializeAsString();
  }
};

// A freshly constructed VM with no initial state and no patches applied must
// serialize to an empty incremental state.
TEST_F(VmTest, NoPatch) {
  const auto instructions = SamplePrograms::IncrementalTraceInstructions();
  auto program = instructions.SerializeAsString();
  Vm vm{AsConstBytes(program), MEMORY_LIMIT_BYTES};

  const std::string serialized_state = SerializeIncrementalStateAsString(vm);
  ASSERT_TRUE(serialized_state.empty());
}

// Constructing the VM with a serialized initial state must make that state
// visible through SerializeIncrementalState() without applying any patch.
TEST_F(VmTest, ConstructionWithInitialIncrementalState) {
  auto program =
      SamplePrograms::IncrementalTraceInstructions().SerializeAsString();

  // Keep the serialized initial state in a named local so the bytes viewed by
  // AsConstBytes() stay alive for the whole test, not just the constructor
  // expression (presumably AsConstBytes() is a non-owning view — confirm).
  const std::string initial_state = InitialIncrementalState();
  Vm vm{AsConstBytes(program), MEMORY_LIMIT_BYTES,
        AsConstBytes(initial_state)};

  protos::TraceEntry state{};
  // Fail loudly if the VM produced bytes that are not a valid TraceEntry.
  ASSERT_TRUE(state.ParseFromString(SerializeIncrementalStateAsString(vm)));
  ASSERT_EQ(state.elements_size(), 2);
  ASSERT_EQ(state.elements(0).id(), 0);
  ASSERT_EQ(state.elements(0).value(), 10);
  ASSERT_EQ(state.elements(1).id(), 1);
  ASSERT_EQ(state.elements(1).value(), 11);
}

// Applies the initial-state patch, then a patch containing a delete
// operation, and checks that only element {id: 1, value: 11} remains.
TEST_F(VmTest, ApplyPatch_DelOperation) {
  auto program =
      SamplePrograms::IncrementalTraceInstructions().SerializeAsString();
  Vm vm{AsConstBytes(program), MEMORY_LIMIT_BYTES};

  // Setup patches must succeed; otherwise the assertions below would report a
  // misleading state mismatch instead of the real failure.
  auto patch = SamplePackets::PatchWithInitialState().SerializeAsString();
  ASSERT_TRUE(vm.ApplyPatch(AsConstBytes(patch)).IsOk());

  patch = SamplePackets::PatchWithDelOperation().SerializeAsString();
  ASSERT_TRUE(vm.ApplyPatch(AsConstBytes(patch)).IsOk());

  protos::TraceEntry state{};
  ASSERT_TRUE(state.ParseFromString(SerializeIncrementalStateAsString(vm)));
  ASSERT_EQ(state.elements_size(), 1);
  ASSERT_EQ(state.elements(0).id(), 1);
  ASSERT_EQ(state.elements(0).value(), 11);
}

// Applies two merge patches in sequence and checks the incremental state after
// each one: the first creates element 0, the second updates element 0 and
// adds element 1.
TEST_F(VmTest, ApplyPatch_MergeOperation) {
  auto program =
      SamplePrograms::IncrementalTraceInstructions().SerializeAsString();
  Vm vm{AsConstBytes(program), MEMORY_LIMIT_BYTES};

  // patch #1
  {
    auto patch = SamplePackets::PatchWithMergeOperation1().SerializeAsString();
    ASSERT_TRUE(vm.ApplyPatch(AsConstBytes(patch)).IsOk());

    protos::TraceEntry state{};
    ASSERT_TRUE(state.ParseFromString(SerializeIncrementalStateAsString(vm)));
    ASSERT_EQ(state.elements_size(), 1);
    ASSERT_EQ(state.elements(0).id(), 0);
    ASSERT_EQ(state.elements(0).value(), 10);
  }

  // patch #2
  {
    auto patch = SamplePackets::PatchWithMergeOperation2().SerializeAsString();
    ASSERT_TRUE(vm.ApplyPatch(AsConstBytes(patch)).IsOk());

    protos::TraceEntry state{};
    ASSERT_TRUE(state.ParseFromString(SerializeIncrementalStateAsString(vm)));
    ASSERT_EQ(state.elements_size(), 2);
    ASSERT_EQ(state.elements(0).id(), 0);
    ASSERT_EQ(state.elements(0).value(), 100);
    ASSERT_EQ(state.elements(1).id(), 1);
    ASSERT_EQ(state.elements(1).value(), 101);
  }
}

// Checks the incremental state in three stages: empty, after the
// initial-state patch, and after a set-operation patch. After the set
// operation, element 0 has no value and element 1 has value 101.
TEST_F(VmTest, ApplyPatch_SetOperation) {
  auto program =
      SamplePrograms::IncrementalTraceInstructions().SerializeAsString();
  Vm vm{AsConstBytes(program), MEMORY_LIMIT_BYTES};

  // empty
  {
    protos::TraceEntry state{};
    ASSERT_TRUE(state.ParseFromString(SerializeIncrementalStateAsString(vm)));
    ASSERT_EQ(state.elements_size(), 0);
  }

  // patch #1
  {
    auto patch = SamplePackets::PatchWithInitialState().SerializeAsString();
    ASSERT_TRUE(vm.ApplyPatch(AsConstBytes(patch)).IsOk());

    protos::TraceEntry state{};
    ASSERT_TRUE(state.ParseFromString(SerializeIncrementalStateAsString(vm)));
    ASSERT_EQ(state.elements_size(), 2);
    ASSERT_EQ(state.elements(0).id(), 0);
    ASSERT_EQ(state.elements(0).value(), 10);
    ASSERT_EQ(state.elements(1).id(), 1);
    ASSERT_EQ(state.elements(1).value(), 11);
  }

  // patch #2
  {
    auto patch = SamplePackets::PatchWithSetOperation().SerializeAsString();
    ASSERT_TRUE(vm.ApplyPatch(AsConstBytes(patch)).IsOk());

    protos::TraceEntry state{};
    ASSERT_TRUE(state.ParseFromString(SerializeIncrementalStateAsString(vm)));
    ASSERT_EQ(state.elements_size(), 2);
    ASSERT_EQ(state.elements(0).id(), 0);
    ASSERT_FALSE(state.elements(0).has_value());
    ASSERT_EQ(state.elements(1).id(), 1);
    ASSERT_EQ(state.elements(1).value(), 101);
  }
}

// A patch that does not match the program's expected field layout must abort,
// and the innermost stacktrace entry must describe the scalar/length-delimited
// mismatch.
TEST_F(VmTest, ApplyPatch_ErrorHandling) {
  auto program =
      SamplePrograms::IncrementalTraceInstructions().SerializeAsString();
  Vm vm{AsConstBytes(program), MEMORY_LIMIT_BYTES};

  auto bad_patch =
      SamplePackets::PatchInconsistentWithIncrementalTraceProgram();
  auto result = vm.ApplyPatch(AsConstBytes(bad_patch));
  ASSERT_TRUE(result.IsAbort());

  const auto& frames = result.stacktrace();
  ASSERT_FALSE(frames.empty());
  ASSERT_THAT(frames.front(),
              ::testing::HasSubstr(
                  "Attempted to access length-delimited field as a scalar"));
}

// A read-only clone must reject patches while exposing a copy of the source
// VM's incremental state at clone time.
TEST_F(VmTest, CloneReadOnly) {
  auto program =
      SamplePrograms::IncrementalTraceInstructions().SerializeAsString();
  Vm vm{AsConstBytes(program), MEMORY_LIMIT_BYTES};

  // The source VM must accept the patch; otherwise the clone checks below
  // would be comparing against a state that was never populated.
  auto patch = SamplePackets::PatchWithInitialState().SerializeAsString();
  ASSERT_TRUE(vm.ApplyPatch(AsConstBytes(patch)).IsOk());

  std::unique_ptr<Vm> cloned_vm = vm.CloneReadOnly();
  ASSERT_NE(cloned_vm, nullptr);

  // Check read-only VM doesn't accept patches
  ASSERT_TRUE(cloned_vm->ApplyPatch(AsConstBytes(patch)).IsAbort());

  // Check cloned incremental state
  protos::TraceEntry cloned_state{};
  ASSERT_TRUE(cloned_state.ParseFromString(
      SerializeIncrementalStateAsString(*cloned_vm)));
  ASSERT_EQ(cloned_state.elements_size(), 2);
  ASSERT_EQ(cloned_state.elements(0).id(), 0);
  ASSERT_EQ(cloned_state.elements(0).value(), 10);
  ASSERT_EQ(cloned_state.elements(1).id(), 1);
  ASSERT_EQ(cloned_state.elements(1).value(), 11);
}

// Memory accounting: an empty VM (and its read-only clone) reports exactly the
// program size. After the incremental state is populated, both report more
// than that.
TEST_F(VmTest, GetMemoryUsage) {
  auto program =
      SamplePrograms::IncrementalTraceInstructions().SerializeAsString();
  Vm vm{AsConstBytes(program), MEMORY_LIMIT_BYTES};

  // Initial memory usage only accounts for the program size
  ASSERT_EQ(vm.GetMemoryUsageBytes(), program.size());
  ASSERT_EQ(vm.CloneReadOnly()->GetMemoryUsageBytes(), program.size());

  // Populating the incremental state increases memory usage. The patch must
  // succeed, or the growth checks below would not be testing what they claim.
  auto patch = SamplePackets::PatchWithInitialState().SerializeAsString();
  ASSERT_TRUE(vm.ApplyPatch(AsConstBytes(patch)).IsOk());
  ASSERT_GT(vm.GetMemoryUsageBytes(), program.size());
  ASSERT_GT(vm.CloneReadOnly()->GetMemoryUsageBytes(), program.size());
}

}  // namespace test
}  // namespace protovm
}  // namespace perfetto
