알고리즘/백준

[백준] 합이 0

HYJJ 2022. 5. 13. 16:50

문제 링크 : https://www.acmicpc.net/problem/3151 

 

3151번: 합이 0

Elly는 예상치 못하게 프로그래밍 대회를 준비하는 학생들을 가르칠 위기에 처했다. 대회는 정확히 3명으로 구성된 팀만 참가가 가능하다. 그러나 그녀가 가르칠 학생들에게는 큰 문제가 있었다.

www.acmicpc.net

처음에는 3중 for문을 생각했으나 그렇게 풀면 O(n^3)이 나오니 당연히 안 될 거 같았다.

 

생각 중에 아래에 태그를 보니까 이렇게 나와 있었다. 

전체 탐색을 하지만 이분 탐색으로 탐색을 하는 것과 두 포인터를 잡는 것을 보면서

먼저 입력된 배열을 sort 하고 -> 이분 탐색을 하는데 포인터를 두 개를 두거나 이분 탐색을 두 번 해서 합을 찾는 것으로 생각을 해봤다. 

 

 

그래서 첫번째로 짜놓은 코드는 다음과 같다.  

package baekjoon;

import java.util.Arrays;
import java.util.Scanner;

public class TotalZero3151 {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int arr[] = new int[N];
        for (int i = 0; i < N; i++) {
            arr[i] = sc.nextInt();
        }
        Arrays.sort(arr);
        // -6 -5 -4 -4 0 1 2 2 3 7
        for (int i = 0; i < arr.length; i++) {
            int tmp = arr[i];
            int first = i+1;
            int last = arr.length;
            int mid = (first+last)/2;
            while(first<last) {

            }
        }

    }
}

다시 고민에 빠졌는데 포인터 두개?를 어떻게 둘지 였다.  일단 고민을 더 해봐야겠다...!


(5.16 작성)

알고리즘 힌트에서 투 포인터로 수를 찾는 방법을 생각해냈다. 

일단 백준으로 구현한 코드는 다음과 같다. 

import java.util.Arrays;
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int arr[] = new int[N];

        for (int i = 0; i < N; i++) {
            arr[i] = sc.nextInt();
        }

        // 2 -5 2 3 -4 7 -4 0 1 6

        Arrays.sort(arr);
        
        // -5 -4 -4 0 1 2 2 3 6 7

        int count = 0;

        for (int i = 0; i < N; i++) {
            int ans = 0;
            ans += arr[i];

            int left = i+1;
            int right = N-1;

            while(left<right) {
                ans += arr[left] + arr[right];
                if (ans < 0) {
                    left += 1;
                } else if (ans > 0) {
                    right -= 1;
                } else {
                   if(arr[left] == arr[right]) {
                       count += (left-right)*(left-right-1)/2;
                   } else {
                       int l_flag = 0;
                       int r_flag = 0;
                       while(arr[left+l_flag]==arr[left]) {
                           l_flag += 1; 
                       }
                       while(arr[right+r_flag]==arr[right]) {
                           r_flag -=1;
                       }
                       left += l_flag;
                       right -= r_flag;
                       count += (l_flag) * (r_flag);
                   }
                }
            }

        }
        System.out.println(count);
       }    
}

써 본 경우를 생각해보면 다음과 같다.

        for (int i = 0; i < N; i++) {
            int ans = 0;
            ans += arr[i];

            int left = i+1;
            int right = N-1;
        }

arr[i]를 한 수 left와 right를 투 포인터로 arr[right], arr[left]로 세 수를 잡았다.

그래서 right < left 까지 while문을 돌아 범위로 좁혀가면서 값을 찾아내는 방식을 활용하였다. 

 

작성한 while 문은 다음과 같다. 

while(left<right) {
    ans += arr[left] + arr[right];
    if (ans < 0) {
        left += 1;
    } else if (ans > 0) {
        right -= 1;
    } else {
       if(arr[left] == arr[right]) {
           count += (left-right)*(left-right-1)/2;
       } else {
           int l_flag = 0;
           int r_flag = 0;
           while(arr[left+l_flag]==arr[left]) {
               l_flag += 1; 
           }
           while(arr[right+r_flag]==arr[right]) {
               r_flag -=1;
           }
           left = l_flag;
           right = r_flag;
           count += (l_flag) * (r_flag);
       }
    }
}

여기서 신경써야 할 부분은 합쳐서 0이 나오는 부분이다.

 

0이 나오는 부분이 왜? 라고 생각할 수 있는데 중복된 수가 있는 배열이기 때문에 그렇다.

 

가령 -2(i) 1(right) 1 1 1(left) 이라는 수열에서는 1이 5개 이기 때문에 5개 중에서 2개를 선택하는 조합으로 count를 더해야 한다. n 개의 수(인덱스) 에서 2개를 선택하기 때문에 n*(n-1)/2 라고 할 수 있고, 여기서 적용을 하면 (left-right) * (left - right -1)/2 라고 할 수 있겠다.

 

또한 -5 (i)-4(right) -4 0 1(left) 1 이러한 수열이 있다면 -4가 2개이고 1이 2개 이기 때문에 각 수의 갯수에서 1개를 선택하는 조합 식을 사용하여 곱해야 한다. left 에서는 하나씩 인덱스를 더해가면서 동일한 부분을 살피고, right 부분에서는 인덱스를 하나씩 내려가는 식으로 살펴본다. 이를 l_flag과 r_flag로 정의하였다. 

 

else {
       if(arr[left] == arr[right]) {
           count += (left-right)*(left-right-1)/2;
       } else {
           int l_flag = 0;
           int r_flag = 0;
           while(arr[left+l_flag]==arr[left]) {
               l_flag += 1; 
           }
           while(arr[right-r_flag]==arr[right]) {
               r_flag +=1;
           }
           left = l_flag;
           right = r_flag;
           count += (l_flag) * (r_flag);
       }
    }

신경을 써서 다시 코드를 짜봤다.

 

최종적인 코드는 아래와 같고 이제 한번 돌려봐야겠다. 

import java.util.Arrays;
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int N = sc.nextInt();
        int arr[] = new int[N];

        for (int i = 0; i < N; i++) {
            arr[i] = sc.nextInt();
        }

        // 2 -5 2 3 -4 7 -4 0 1 6

        Arrays.sort(arr);
        
        // -5 -4 -4 0 1 2 2 3 6 7

        int count = 0;

        for (int i = 0; i < N; i++) {
            int ans = 0;
            ans += arr[i];

            int left = i+1;
            int right = N-1;

            while(left<right) {
                ans += arr[left] + arr[right];
                if (ans < 0) {
                    left += 1;
                } else if (ans > 0) {
                    right -= 1;
                } else {
                   if(arr[left] == arr[right]) {
                       count += (left-right)*(left-right-1)/2;
                   } else {
                       int l_flag = 0;
                       int r_flag = 0;
                       while(arr[left+l_flag]==arr[left]) {
                           l_flag += 1; 
                       }
                       while(arr[right-r_flag]==arr[right]) {
                           r_flag +=1;
                       }
                       left = l_flag;
                       right = r_flag;
                       count += (l_flag) * (r_flag);
                   }
                }
            }

        }
        System.out.println(count);
       }    
}