1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
# 线段树的节点类
class SegTreeNode:
def __init__(self, val=[0, 1]):
self.left = -1 # 区间左边界
self.right = -1 # 区间右边界
self.val = val # 节点值(区间值)
# 线段树类
class SegmentTree:
# 初始化线段树接口
def __init__(self, size):
self.size = size
self.tree = [SegTreeNode() for _ in range(4 * self.size)] # 维护 SegTreeNode 数组
if self.size > 0:
self.__build(0, 0, self.size - 1)
# 单点更新接口:将 nums[i] 更改为 val
def update_point(self, i, val):
self.__update_point(i, val, 0)
# 区间查询接口:查询区间为 [q_left, q_right] 的区间值
def query_interval(self, q_left, q_right):
return self.__query_interval(q_left, q_right, 0)
# 以下为内部实现方法
# 构建线段树实现方法:节点的存储下标为 index,节点的区间为 [left, right]
def __build(self, index, left, right):
self.tree[index].left = left
self.tree[index].right = right
if left == right: # 叶子节点,节点值为对应位置的元素值
self.tree[index].val = [0, 0]
return
mid = left + (right - left) // 2 # 左右节点划分点
left_index = index * 2 + 1 # 左子节点的存储下标
right_index = index * 2 + 2 # 右子节点的存储下标
self.__build(left_index, left, mid) # 递归创建左子树
self.__build(right_index, mid + 1, right) # 递归创建右子树
self.tree[index].val = self.merge(self.tree[left_index].val, self.tree[right_index].val) # 向上更新节点的区间值
# 单点更新实现方法:将 nums[i] 更改为 val,节点的存储下标为 index
def __update_point(self, i, val, index):
left = self.tree[index].left
right = self.tree[index].right
if left == i and right == i:
self.tree[index].val = self.merge(self.tree[index].val, val)
return
mid = left + (right - left) // 2 # 左右节点划分点
left_index = index * 2 + 1 # 左子节点的存储下标
right_index = index * 2 + 2 # 右子节点的存储下标
if i <= mid: # 在左子树中更新节点值
self.__update_point(i, val, left_index)
else: # 在右子树中更新节点值
self.__update_point(i, val, right_index)
self.tree[index].val = self.merge(self.tree[left_index].val, self.tree[right_index].val) # 向上更新节点的区间值
# 区间查询实现方法:在线段树中搜索区间为 [q_left, q_right] 的区间值
def __query_interval(self, q_left, q_right, index):
left = self.tree[index].left
right = self.tree[index].right
if left >= q_left and right <= q_right: # 节点所在区间被 [q_left, q_right] 所覆盖
return self.tree[index].val # 直接返回节点值
if right < q_left or left > q_right: # 节点所在区间与 [q_left, q_right] 无关
return [0, 0]
mid = left + (right - left) // 2 # 左右节点划分点
left_index = index * 2 + 1 # 左子节点的存储下标
right_index = index * 2 + 2 # 右子节点的存储下标
res_left = [0, 0]
res_right = [0, 0]
if q_left <= mid: # 在左子树中查询
res_left = self.__query_interval(q_left, q_right, left_index)
if q_right > mid: # 在右子树中查询
res_right = self.__query_interval(q_left, q_right, right_index)
# 返回合并结果
return self.merge(res_left, res_right)
# 向上合并实现方法
def merge(self, val1, val2):
val = [0, 0]
if val1[0] == val2[0]: # 递增子序列长度一致,则合并后最长递增子序列个数为之前两者之和
val = [val1[0], val1[1] + val2[1]]
elif val1[0] < val2[0]: # 如果递增子序列长度不一致,则合并后最长递增子序列个数取较长一方的个数
val = [val2[0], val2[1]]
else:
val = [val1[0], val1[1]]
return val
class Solution:
def findNumberOfLIS(self, nums: List[int]) -> int:
# 离散化处理
num_dict = dict()
nums_sort = sorted(nums)
for i in range(len(nums_sort)):
num_dict[nums_sort[i]] = i
# 构造线段树
self.STree = SegmentTree(len(nums_sort))
for num in nums:
index = num_dict[num]
# 查询 [0, nums[index - 1]] 区间上以 nums[index - 1] 结尾的子序列所能达到的最长递增子序列长度和对应数量
val = self.STree.query_interval(0, index - 1)
# 如果当前最长递增子序列长度为 0,则加入 num 之后最长递增子序列长度为 1,且数量为 1
# 如果当前最长递增子序列长度不为 0,则加入 num 之后最长递增子序列长度 +1,但数量不变
if val[0] == 0:
val = [1, 1]
else:
val = [val[0] + 1, val[1]]
self.STree.update_point(index, val)
return self.STree.query_interval(0, len(nums_sort) - 1)[1]
|