首先对 a、b、c排序 暴力超时
# coding=utf8 """ 给定三个整数数组 A = [A1, A2, ... AN], B = [B1, B2, ... BN], C = [C1, C2, ... CN], 请你统计有多少个三元组(i, j, k) 满足: 1. 1 <= i, j, k <= N 2. Ai < Bj < Ck 【样例输入 3 1 1 1 2 2 2 3 3 3 """ import sys N = int(input()) a = sorted(list(map(int, (input().split())))) b = sorted(list(map(int, (input().split())))) c = sorted(list(map(int, (input().split())))) if len(a) != len(b) or len(b) != len(c): print(None) # print(a) # print(b) # print(c) SUM = 0 la, lb, lc = N - 1, N - 1, N - 1 if a[0] >= b[lb] or b[0] >= c[lc]: # a数组中最小的数大于等于b中最大的数 o # b数组中最小的数大于等于c中最大的数 r print(0) sys.exit() # 退出程序 if c[0] > b[lb] and b[0] > a[la]: # b数组中最小的数大于a中最大的数 并 # c数组中最小的数大于b中最大的数 且 print(N*N*N) sys.exit() for i in a: for j in b: for k in c: if i < j < k: SUM += 1 print(SUM)二分找到max(j) 使得a[i] < b[j]
N = int(input()) a = sorted(list(map(int, (input().split())))) b = sorted(list(map(int, (input().split())))) c = sorted(list(map(int, (input().split())))) if len(a) != len(b) or len(b) != len(c): print(None) def search_lower_idx(num, nums): if num >= nums[-1]: return -1 if num < nums[0]: return 0 idx = len(nums)//2 while num > nums[idx]: idx = (len(nums) + idx)//2 while num < nums[idx]: idx //= 2 while num == nums[idx]: idx += 1 return idx SUM = 0 for la in range(N): # print('la', la) idxb = search_lower_idx(a[la], b) # print('idxb', idxb) if idxb == -1: continue for lb in range(idxb, N): # print(f'lb {lb}') idxc = search_lower_idx(b[lb], c) # print(f'idxc {idxc}') if idxc != -1: SUM += N - idxc # print(f'SUM {SUM}') # print(f'la {la} SUM {SUM}\n') print(SUM) """ 4 3 4 2 1 3 4 2 1 3 4 2 1 """二分的方法对于某些测试不通过 为什么呢 这是为什么呢?