[백준] 합이 0
문제 링크 : https://www.acmicpc.net/problem/3151
3151번: 합이 0
Elly는 예상치 못하게 프로그래밍 대회를 준비하는 학생들을 가르칠 위기에 처했다. 대회는 정확히 3명으로 구성된 팀만 참가가 가능하다. 그러나 그녀가 가르칠 학생들에게는 큰 문제가 있었다.
www.acmicpc.net
처음에는 3중 for문을 생각했으나 그렇게 풀면 O(n^3)이 나오니 당연히 안 될 거 같았다.
생각 중에 아래에 태그를 보니까 이렇게 나와 있었다.
전체 탐색을 하지만 이분 탐색으로 탐색을 하는 것과 두 포인터를 잡는 것을 보면서
먼저 입력된 배열을 sort 하고 -> 이분 탐색을 하는데 포인터를 두 개를 두거나 이분 탐색을 두 번 해서 합을 찾는 것으로 생각을 해봤다.
그래서 첫번째로 짜놓은 코드는 다음과 같다.
package baekjoon;
import java.util.Arrays;
import java.util.Scanner;
public class TotalZero3151 {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int N = sc.nextInt();
int arr[] = new int[N];
for (int i = 0; i < N; i++) {
arr[i] = sc.nextInt();
}
Arrays.sort(arr);
// -6 -5 -4 -4 0 1 2 2 3 7
for (int i = 0; i < arr.length; i++) {
int tmp = arr[i];
int first = i+1;
int last = arr.length;
int mid = (first+last)/2;
while(first<last) {
}
}
}
}
다시 고민에 빠졌는데 포인터 두개?를 어떻게 둘지 였다. 일단 고민을 더 해봐야겠다...!
(5.16 작성)
알고리즘 힌트에서 투 포인터로 수를 찾는 방법을 생각해냈다.
일단 백준으로 구현한 코드는 다음과 같다.
import java.util.Arrays;
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int N = sc.nextInt();
int arr[] = new int[N];
for (int i = 0; i < N; i++) {
arr[i] = sc.nextInt();
}
// 2 -5 2 3 -4 7 -4 0 1 6
Arrays.sort(arr);
// -5 -4 -4 0 1 2 2 3 6 7
int count = 0;
for (int i = 0; i < N; i++) {
int ans = 0;
ans += arr[i];
int left = i+1;
int right = N-1;
while(left<right) {
ans += arr[left] + arr[right];
if (ans < 0) {
left += 1;
} else if (ans > 0) {
right -= 1;
} else {
if(arr[left] == arr[right]) {
count += (left-right)*(left-right-1)/2;
} else {
int l_flag = 0;
int r_flag = 0;
while(arr[left+l_flag]==arr[left]) {
l_flag += 1;
}
while(arr[right+r_flag]==arr[right]) {
r_flag -=1;
}
left += l_flag;
right -= r_flag;
count += (l_flag) * (r_flag);
}
}
}
}
System.out.println(count);
}
}
써 본 경우를 생각해보면 다음과 같다.
for (int i = 0; i < N; i++) {
int ans = 0;
ans += arr[i];
int left = i+1;
int right = N-1;
}
arr[i]를 한 수 left와 right를 투 포인터로 arr[right], arr[left]로 세 수를 잡았다.
그래서 right < left 까지 while문을 돌아 범위로 좁혀가면서 값을 찾아내는 방식을 활용하였다.
작성한 while 문은 다음과 같다.
while(left<right) {
ans += arr[left] + arr[right];
if (ans < 0) {
left += 1;
} else if (ans > 0) {
right -= 1;
} else {
if(arr[left] == arr[right]) {
count += (left-right)*(left-right-1)/2;
} else {
int l_flag = 0;
int r_flag = 0;
while(arr[left+l_flag]==arr[left]) {
l_flag += 1;
}
while(arr[right+r_flag]==arr[right]) {
r_flag -=1;
}
left = l_flag;
right = r_flag;
count += (l_flag) * (r_flag);
}
}
}
여기서 신경써야 할 부분은 합쳐서 0이 나오는 부분이다.
0이 나오는 부분이 왜? 라고 생각할 수 있는데 중복된 수가 있는 배열이기 때문에 그렇다.
가령 -2(i) 1(right) 1 1 1(left) 이라는 수열에서는 1이 5개 이기 때문에 5개 중에서 2개를 선택하는 조합으로 count를 더해야 한다. n 개의 수(인덱스) 에서 2개를 선택하기 때문에 n*(n-1)/2 라고 할 수 있고, 여기서 적용을 하면 (left-right) * (left - right -1)/2 라고 할 수 있겠다.
또한 -5 (i)-4(right) -4 0 1(left) 1 이러한 수열이 있다면 -4가 2개이고 1이 2개 이기 때문에 각 수의 갯수에서 1개를 선택하는 조합 식을 사용하여 곱해야 한다. left 에서는 하나씩 인덱스를 더해가면서 동일한 부분을 살피고, right 부분에서는 인덱스를 하나씩 내려가는 식으로 살펴본다. 이를 l_flag과 r_flag로 정의하였다.
else {
if(arr[left] == arr[right]) {
count += (left-right)*(left-right-1)/2;
} else {
int l_flag = 0;
int r_flag = 0;
while(arr[left+l_flag]==arr[left]) {
l_flag += 1;
}
while(arr[right-r_flag]==arr[right]) {
r_flag +=1;
}
left = l_flag;
right = r_flag;
count += (l_flag) * (r_flag);
}
}
신경을 써서 다시 코드를 짜봤다.
최종적인 코드는 아래와 같고 이제 한번 돌려봐야겠다.
import java.util.Arrays;
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int N = sc.nextInt();
int arr[] = new int[N];
for (int i = 0; i < N; i++) {
arr[i] = sc.nextInt();
}
// 2 -5 2 3 -4 7 -4 0 1 6
Arrays.sort(arr);
// -5 -4 -4 0 1 2 2 3 6 7
int count = 0;
for (int i = 0; i < N; i++) {
int ans = 0;
ans += arr[i];
int left = i+1;
int right = N-1;
while(left<right) {
ans += arr[left] + arr[right];
if (ans < 0) {
left += 1;
} else if (ans > 0) {
right -= 1;
} else {
if(arr[left] == arr[right]) {
count += (left-right)*(left-right-1)/2;
} else {
int l_flag = 0;
int r_flag = 0;
while(arr[left+l_flag]==arr[left]) {
l_flag += 1;
}
while(arr[right-r_flag]==arr[right]) {
r_flag +=1;
}
left = l_flag;
right = r_flag;
count += (l_flag) * (r_flag);
}
}
}
}
System.out.println(count);
}
}