Segment Tree Implementation Code (C++)

Updated:

theory link 1

theory link 2

Read

long long 데이터를 기준으로 만든 클래스이다. 이 코드를 토대로 필요한 자료형으로 바꿔서 사용하거나 template으로 일반화해서 사용하면 된다.

알고리즘 공부용으로 작성한 코드라 검증이 완벽하지 않으니, 사용하기 전에 직접 충분히 검증해 보는 것을 추천한다. (BOJ 10999 문제는 LPSegmentTree.h로 해결했다.)

SegmentTree.h

/**
 * Bottom-up (iterative) segment tree over long long values.
 *
 * Nodes live in one flat array of 2 * treeSize entries: the leaves occupy
 * indices [treeSize, 2 * treeSize) and node i has children 2i and 2i + 1.
 * Index 0 is never used.
 *
 * The methods support two usage modes that interpret internal nodes
 * differently:
 *   - point assignment + range sum: Build, ModifyElement, QueryInterval
 *     (an internal node stores the sum of its two children);
 *   - range increment + point read: ModifyInterval, QueryElement, Push
 *     (an internal node stores an increment that applies to every leaf
 *     below it).
 */
class SegmentTree
{
private:
    // Flat node storage, see the class comment for the layout.
    std::vector<long long> tree;
    // Number of leaves. Zero until Input() is called, so the member is never
    // read while indeterminate.
    int treeSize = 0;

public:
    // Sets the element count and reads that many values from std::cin.
    void Input(int treeSize);
    // Fills every internal node with the sum of its two children.
    void Build();
    // Moves the increments held in internal nodes down into the leaves.
    void Push();
    // Assigns `value` to one element and refreshes the sums above it.
    void ModifyElement(int position, long long value);
    // Adds `value` to every element of the half-open range [left, right).
    void ModifyInterval(int left, int right, long long value);
    // Sums the nodes on the path from the given leaf up to the root.
    long long QueryElement(int position);
    // Sums the nodes that cover the half-open range [left, right).
    long long QueryInterval(int left, int right);

    // Writes nodes 1 .. 2 * treeSize - 1 to std::cout on one line.
    void Print();
};

/**
 * Sets the element count and reads that many values from std::cin into the
 * leaf slots tree[treeSize .. 2 * treeSize - 1].
 *
 * All internal nodes are reset to zero. Call Build() afterwards if the tree
 * is going to be used for range-sum queries.
 *
 * @param treeSize number of elements to read.
 */
void SegmentTree::Input(int treeSize)
{
    this->treeSize = treeSize;

    // assign() instead of resize(): resize() keeps the elements that are
    // already there, so calling Input() a second time on the same object would
    // leave old values in the internal nodes. ModifyInterval()/QueryElement()
    // treat internal nodes as increments and need them to start at zero.
    this->tree.assign(static_cast<std::vector<long long>::size_type>(this->treeSize * 2), 0LL);

    for (int i = 0; i < this->treeSize; i++)
    {
        std::cin >> this->tree[this->treeSize + i];
    }
}

/**
 * Computes every internal node as the sum of its two children.
 *
 * Nodes are visited from the highest internal index down to the root (index
 * 1), so both children of a node are final before the node itself is written.
 */
void SegmentTree::Build()
{
    int node = this->treeSize - 1;
    while (node > 0)
    {
        const int leftChild = node << 1;
        const int rightChild = leftChild | 1;
        this->tree[node] = this->tree[leftChild] + this->tree[rightChild];
        --node;
    }
}

/**
 * Pushes the value held in each internal node down to both of its children
 * and then clears the node.
 *
 * Nodes are processed from the root downwards, so a value that starts at the
 * root ends up added to every leaf below it, and all internal nodes are zero
 * when the function returns.
 */
void SegmentTree::Push()
{
    int node = 1;
    while (node < this->treeSize)
    {
        const long long pending = this->tree[node];
        const int leftChild = node << 1;

        this->tree[leftChild] += pending;
        this->tree[leftChild | 1] += pending;
        this->tree[node] = 0;

        ++node;
    }
}

/**
 * Assigns `value` to the element at `position` and recomputes the sum stored
 * in each ancestor of that leaf, up to and including the root.
 *
 * @param position zero-based element index.
 * @param value    new value for the element.
 */
void SegmentTree::ModifyElement(int position, long long value)
{
    position += this->treeSize;
    this->tree[position] = value;

    while (position > 1)
    {
        // position ^ 1 is the sibling that shares the same parent.
        const int parent = position >> 1;
        this->tree[parent] = this->tree[position] + this->tree[position ^ 1];
        position = parent;
    }
}

/**
 * Adds `value` to the nodes that together cover the half-open range
 * [left, right).
 *
 * Ancestors and descendants of the touched nodes are not updated here; the
 * increment stays on the covering nodes until it is read by QueryElement()
 * or moved to the leaves by Push().
 *
 * @param left  first element index of the range (inclusive).
 * @param right end of the range (exclusive).
 * @param value amount to add.
 */
void SegmentTree::ModifyInterval(int left, int right, long long value)
{
    int lo = left + this->treeSize;
    int hi = right + this->treeSize;

    while (lo < hi)
    {
        // An odd lower bound is a right child: take it and step past it.
        if ((lo & 1) != 0)
        {
            this->tree[lo] += value;
            ++lo;
        }
        // An odd upper bound means the node just below it is a left child
        // that lies inside the range.
        if ((hi & 1) != 0)
        {
            --hi;
            this->tree[hi] += value;
        }

        lo >>= 1;
        hi >>= 1;
    }
}

/**
 * Returns the sum of all nodes on the path from the leaf of `position` up to
 * the root (index 1), both ends included.
 *
 * @param position zero-based element index.
 * @return accumulated value along the leaf-to-root path.
 */
long long SegmentTree::QueryElement(int position)
{
    long long total = 0;
    int node = position + this->treeSize;

    while (node > 0)
    {
        total += this->tree[node];
        node >>= 1;
    }

    return total;
}

/**
 * Returns the sum over the half-open range [left, right).
 *
 * The range is decomposed bottom-up into covering nodes, and the values
 * stored in those nodes are added together.
 *
 * @param left  first element index of the range (inclusive).
 * @param right end of the range (exclusive).
 * @return sum of the covering nodes.
 */
long long SegmentTree::QueryInterval(int left, int right)
{
    long long total = 0;
    int lo = left + this->treeSize;
    int hi = right + this->treeSize;

    while (lo < hi)
    {
        if ((lo & 1) != 0)
        {
            total += this->tree[lo];
            ++lo;
        }
        if ((hi & 1) != 0)
        {
            --hi;
            total += this->tree[hi];
        }

        lo >>= 1;
        hi >>= 1;
    }

    return total;
}

/**
 * Writes every node from index 1 to the end of the node array to std::cout,
 * each followed by a single space, and finishes the line with a newline.
 * Index 0 is skipped because it is never used.
 */
void SegmentTree::Print()
{
    // Use the container's own index type so the loop bound is not a
    // signed/unsigned comparison.
    using Index = std::vector<long long>::size_type;

    for (Index i = 1; i < this->tree.size(); ++i)
    {
        std::cout << this->tree[i] << " ";
    }
    std::cout << "\n";
}

LPSegmentTree.h (Lazy Propagation Segment Tree)

/**
 * Lazy Propagation Segment Tree.
 *
 * Bottom-up segment tree over long long values that supports adding a value
 * to a half-open range and querying the sum of a half-open range.
 *
 * Node layout: `tree` holds 2 * treeSize nodes, with the leaves at indices
 * [treeSize, 2 * treeSize) and node i having children 2i and 2i + 1. `d`
 * holds, for each internal node, the increment that has been applied to the
 * node itself but not yet passed on to its children.
 */
class LPSegmentTree
{
private:
    // Node sums.
    std::vector<long long> tree;
    // Pending (not yet pushed) increment per internal node.
    std::vector<long long> d;
    // Number of leaves. Zero until Input() is called, so the member is never
    // read while indeterminate.
    int treeSize = 0;
    // Bit length of treeSize; the number of levels PushInterval() walks.
    int treeHeight = 0;

public:
    // Sets the element count and reads that many values from std::cin.
    void Input(int treeSize);
    // Recomputes the internal nodes above the leaves of [left, right).
    void BuildInterval(int left, int right);
    // Recomputes one internal node; k is the number of leaves it covers.
    void Calculate(int position, int k);
    // Adds `value` per leaf to a node covering k leaves.
    void Apply(int position, long long value, int k);
    // Pushes pending increments down along the ancestors of [left, right).
    void PushInterval(int left, int right);
    // Adds `value` to every element of [left, right).
    void ModifyInterval(int left, int right, long long value);
    // Returns the sum of the elements of [left, right).
    long long QueryInterval(int left, int right);

    // Writes the node sums and then the pending increments to std::cout.
    void Print();
};

/**
 * Sets the element count, derives the tree height, and reads that many
 * values from std::cin into the leaf slots tree[treeSize .. 2 * treeSize - 1].
 *
 * Internal sums and all pending increments are reset to zero. The internal
 * sums are not computed here; call BuildInterval(0, treeSize) afterwards.
 *
 * @param treeSize number of elements to read.
 */
void LPSegmentTree::Input(int treeSize)
{
    this->treeSize = treeSize;

    // Height = bit length of treeSize. __builtin_clz(0) is undefined, so an
    // empty tree gets height 0 without calling it.
    this->treeHeight = (treeSize > 0)
        ? static_cast<int>(sizeof(int) * 8) - __builtin_clz(treeSize)
        : 0;

    // assign() instead of resize(): resize() keeps the elements that are
    // already there, so calling Input() again on a used object would carry
    // over old pending increments in `d` and old sums in `tree`.
    this->tree.assign(static_cast<std::vector<long long>::size_type>(this->treeSize * 2), 0LL);
    this->d.assign(static_cast<std::vector<long long>::size_type>(this->treeSize), 0LL);

    for (int i = 0; i < this->treeSize; i++)
    {
        std::cin >> this->tree[this->treeSize + i];
    }
}

/**
 * Recomputes, level by level from the bottom up, every internal node that
 * lies above the leaves of the half-open range [left, right).
 *
 * @param left  first element index of the range (inclusive).
 * @param right end of the range (exclusive).
 */
void LPSegmentTree::BuildInterval(int left, int right)
{
    int lo = left + this->treeSize;
    int hi = right + this->treeSize - 1;   // last leaf inside the range
    int width = 2;                         // leaves covered by a node on the current level

    while (lo > 1)
    {
        lo >>= 1;
        hi >>= 1;

        for (int node = hi; node >= lo; --node)
        {
            Calculate(node, width);
        }

        width <<= 1;
    }
}

/**
 * Recomputes the sum stored in an internal node from its two children plus
 * the node's own pending increment.
 *
 * Apply() adds an increment to the node's sum and records it in `d` without
 * touching the children, so the node's sum is always
 *     tree[left child] + tree[right child] + d[position] * k.
 * The earlier form added d[position] * k on top of the existing sum when a
 * pending increment was present, which counted that increment twice and
 * ignored any change in the children. When d[position] is zero this
 * expression is the plain sum of the children, exactly as before.
 *
 * @param position index of the internal node to recompute.
 * @param k        number of leaves covered by the node.
 */
void LPSegmentTree::Calculate(int position, int k)
{
    this->tree[position] = this->tree[position << 1]
                         + this->tree[position << 1 | 1]
                         + this->d[position] * k;
}

/**
 * Applies an increment of `value` per leaf to a node that covers k leaves.
 *
 * The node's sum grows by value * k. For an internal node the increment is
 * also recorded in `d` so it can be pushed down to the children later.
 *
 * @param position index of the node.
 * @param value    increment per leaf.
 * @param k        number of leaves covered by the node.
 */
void LPSegmentTree::Apply(int position, long long value, int k)
{
    const long long delta = value * k;
    this->tree[position] += delta;

    // Leaves (index >= treeSize) have no children and no slot in `d`.
    const bool isInternal = position < this->treeSize;
    if (isInternal)
    {
        this->d[position] += value;
    }
}

/**
 * Pushes pending increments down to the children for every node that lies
 * above the leaves of the half-open range [left, right), working from the
 * top level down to the level just above the leaves.
 *
 * @param left  first element index of the range (inclusive).
 * @param right end of the range (exclusive).
 */
void LPSegmentTree::PushInterval(int left, int right)
{
    const int firstLeaf = left + this->treeSize;
    const int lastLeaf = right + this->treeSize - 1;

    // Leaves covered by a child of a node on the current level.
    int childWidth = 1 << (this->treeHeight - 1);

    for (int level = this->treeHeight; level > 0; --level)
    {
        const int lastNode = lastLeaf >> level;

        for (int node = firstLeaf >> level; node <= lastNode; ++node)
        {
            if (this->d[node] != 0)
            {
                Apply(node << 1, this->d[node], childWidth);
                Apply(node << 1 | 1, this->d[node], childWidth);
                this->d[node] = 0;
            }
        }

        childWidth >>= 1;
    }
}

/**
 * Adds `value` to every element of the half-open range [left, right).
 *
 * Pending increments above the two boundary leaves are pushed down first.
 * The range is then decomposed bottom-up into covering nodes, each of which
 * receives the increment through Apply(). While climbing, the parents of
 * nodes that were already updated on the left or right border are recomputed
 * with Calculate(); the second loop continues that recomputation up to the
 * root.
 *
 * @param left  first element index of the range (inclusive).
 * @param right end of the range (exclusive).
 * @param value amount to add to each element; zero returns immediately.
 */
void LPSegmentTree::ModifyInterval(int left, int right, long long value)
{
    // Adding zero changes nothing, so skip the work.
    if (value == 0)
        return;

    // Clear pending increments on the ancestors of the first and last leaf of
    // the range.
    PushInterval(left, left + 1);
    PushInterval(right - 1, right);

    // cLeft / cRight record whether a node has been updated on the left /
    // right border, i.e. whether that border's ancestors need recomputing.
    bool cLeft = false, cRight = false;
    // k is the number of leaves covered by a node on the current level.
    int k = 1;
    for (left += this->treeSize, right += this->treeSize; left < right; left >>= 1, right >>= 1, k <<= 1)
    {
        // left - 1 is the parent of the left-border node handled one level
        // below.
        if (cLeft == true)
        {
            Calculate(left - 1, k);
        }

        // right is the parent of the right-border node handled one level
        // below.
        if (cRight == true)
        {
            Calculate(right, k);
        }

        if (left & 1)
        {
            Apply(left++, value, k);
            cLeft = true;
        }

        if (right & 1)
        {
            Apply(--right, value, k);
            cRight = true;
        }
    }

    // Continue recomputing the border ancestors up to the root. After --left,
    // `left` is the left-border ancestor on the current level.
    for (--left; right > 0; left >>= 1, right >>= 1, k <<= 1)
    {
        if (cLeft == true)
        {
            Calculate(left, k);
        }

        // Skip the right side when it is the node just recomputed for the left
        // side.
        if (cRight == true && (cLeft == false || left != right))
        {
            Calculate(right, k);
        }
    }
}

/**
 * Returns the sum of the elements in the half-open range [left, right).
 *
 * Pending increments above the two boundary leaves are pushed down first,
 * then the sums of the nodes covering the range are added together.
 *
 * @param left  first element index of the range (inclusive).
 * @param right end of the range (exclusive).
 * @return sum over the range.
 */
long long LPSegmentTree::QueryInterval(int left, int right)
{
    PushInterval(left, left + 1);
    PushInterval(right - 1, right);

    int lo = left + this->treeSize;
    int hi = right + this->treeSize;
    long long total = 0;

    while (lo < hi)
    {
        if ((lo & 1) != 0)
        {
            total += this->tree[lo];
            ++lo;
        }
        if ((hi & 1) != 0)
        {
            --hi;
            total += this->tree[hi];
        }

        lo >>= 1;
        hi >>= 1;
    }

    return total;
}

/**
 * Writes two lines to std::cout: first every node sum from index 1 onwards,
 * then every pending increment from index 1 onwards. Each value is followed
 * by a single space and each line ends with a newline. Index 0 is skipped in
 * both arrays.
 */
void LPSegmentTree::Print()
{
    // Use the container's own index type so the loop bounds are not
    // signed/unsigned comparisons.
    using Index = std::vector<long long>::size_type;

    for (Index i = 1; i < this->tree.size(); ++i)
    {
        std::cout << this->tree[i] << " ";
    }
    std::cout << "\n";

    for (Index i = 1; i < this->d.size(); ++i)
    {
        std::cout << this->d[i] << " ";
    }
    std::cout << "\n";
}

Leave a comment