자바스크립트에서 딥러닝 인공지능 자동 미분 구현, 편미분 구현 - 2 - Whitmem
자바스크립트에서 딥러닝 인공지능 자동 미분 구현, 편미분 구현 - 2
AI Development Study
2025-05-01 01:16 게시 9cbf21b53d7e2aa4855a

0
0
31
이 페이지는 외부 공간에 무단 복제할 수 없으며 오직 있는 그대로 게시되며 부정확한 내용을 포함할 수 있습니다. 법률이 허용하는 한 가이드 라인에 맞춰 게시 내용을 인용하거나 출처로 표기할 수 있습니다.
This page is not to be distributed to external services; it is provided as is and may contain inaccuracies.
자동 미분 구현을 위한 계산 클래스 구현
Torch 등에서는 Tensor 내부에 자동 미분 기능이 내장되어있다. x^2 + x*4 + x*50 의 식을 계산하는 과정을 나타낸 것이다. x 미지수에는 3이라는 값을 넣어 연산을 진행한다.
기본적으로 파이썬은 연산자 오버라이딩이 가능하기 때문에, 위와 같은 연산은 자동으로 미분 처리가 진행된다. 이는 https://whitmem.kr/read/438 에서 자동 미분의 원리가 어떻게 되는지 상세하게 설명되어 있으니, 해당 게시글을 참고하길 바란다.
아무튼, 위 자동 미분 기능을 Javascript 웹 브라우저에서 구현해보는 것으로 한다. 안타깝게도 Javascript 에서는 연산자 오버라이딩을 지원하지 않는다. 즉 파이썬에서 가능하던 연산자 오버라이딩, Tensor 와 숫자간 +, -, *, / 의 연산을 미분 가능한 형태로 인스턴스로 관리하고 텐서로 처리하는 기능을 Javascript 로는 구현할 수 없다.
따라서, 메서드로 연산이 가능하게끔 구현해야 하며, 예를 들어 Tensor 로 가정하자면, 아래와 같이 연산을 수행해야 한다.
x = Tensor(3.0) result = x.multiply(x); result = result.add(x.multiply(Tensor(4.0)) result = result.add(x.multiply(Tensor(50))
즉 연산자만 사용하지 않을 뿐이지, 메서드로 구현하면 모두 같다.
자바스크립트에서 사용할 클래스 명은 Comp 로 구현하였고, 단일 값에 대해서만 우선 연산이 가능하도록 구현하였다. 참고로 이 게시글은 자바스크립트에서 사용할 수 있는 코드를 게시하는 목적이 아닌, 구현의 과정을 나타내는 게시글이라 소스코드를 복사 붙여넣기 할 수 있게끔 별도로 게시하지 아니한다. (실제 실용성이 없다..)
생성자는 단일 값 value를 보관할 수 있으며, 내부적으로 미분 정보를 보관한다. 이외 역전파를 위한 함수 정보를 보관, 이전 노드 정보 parentNodes를 보관한다.
먼저 첫 번째로 덧셈 구현이다. 어떤 노드와 덧셈을 하기 위해서, 덧셈 연산 그래프의 과정을 다시 본다.
단순 두 노드의 덧셈에서 미분 값은 두 개의 항 각각 1이다. 따라서 미분 값은 둘다 1이 되고, 이전 노드에서 넘어온 미분 값을 곱해주기만 하면 된다.
즉 add 하는 self (this) 와, 연산 대상 comp 각각 1의 미분 값을 넘겨주되, 두 연산으로 인해 생긴 + 노드에 대한 미분 값을 같이 곱해서 넘어가도록 한다.
다만 유의할 것은, 위 연산은 즉시 실행되는 것이 아니라, backward 당시에 연산되는 것으로 여기서는 정의만 하는 것이다. 즉 당연히 위 코드에서 새로운 덧셈 노드 결과인 result 는 방금 생성되었기 때문에 당연히 미분 값이 존재하지 않는다. 그렇기에 이를 나중에 실행할 수 있는 함수와 정의 명령만 만들어두고 실제 실행은 역전파시에만 실행된다.
역전파는 노드의 모든 순전파가 완료되어야 진행할 수 있기 때문에 역전파 함수를 만들어두고 실행할 수 있는 대기열에 넣어둘 뿐이다.
그리고 새로 생성되는 연산 결과의 부모 노드로 각 덧셈의 두 항 노드를 기록하는데, 이 이유는 새로 생성된 연산 결과 노드를 역전파할 때 이전 노드로 이어지게끔 하기 위해 메모리에 각 노드를 연결짓는 것이다. 역전파 과정은 거꾸로 돌아오면서 첫 노드까지 순회되어야 하기 때문에, 그 정보를 모두 각 노드에 잇는 것이다. 이러한 작업이 없으면 각 연산마다 일일이 backward() 를 수행해야 한다.
다음은 곱셈 노드이다. 곱셈 노드도 덧셈 노드와 원리는 동일하다. 다만 미분 값이 각각 서로 다른 항의 값이다.
x * y 에서, x로 편미분하면 y를 넣고, y로 편미분하면 x를 넣기 때문에 새로 생성되는 연산 노드의 역전파 함수에, this(x) 미분 값에는 y값을 넣고, y 미분값에는 this(x)를 넣도록 프로그래밍한다.
이외 나눗셈 노드와 뺄셈 노드는 설명을 생략한다. 이 역시 각 편미분한 값을 서로 다른 출처 노드의 미분 값에 넣어주는 식으로 프로그래밍하면 된다.
참고로, requestGrid는 별거없다. 단순히 미분 값을 더할 뿐이다.
즉 위와 같은 연산 노드를 만들었고, add, multiply, divide 를 사용해 노드를 이었다고 가정해보자. 우리가 화살표로 표현을 해서 그렇지, 각 인스턴스를 서로 연결된 상태가 아니다.
사실상 컴퓨터에서는 a+b 값을 A로 꺼낸 다음, A 값을 다시 곱셈 노드에 넣어 B*c 를 연산하고, 나온 C 값을 다시 / 노드에 넣어 나눗셈을 진행한다. 즉 논리적으로는 연결된 관계이지만 프로그래밍상으로는 연산 결과를 다시 넘길뿐이다. 즉 역전파를 하기 위해서는 각 노드와의 연결 정보가 있어야 역전파를 연쇄적으로 수행할 수 있다. 예를 들어 E에서 backward()를 한다고 가정하자,
E 에서 backward()를 수행하면 이전 노드로 이어진 모든 backward() 명령을 연쇄적으로 보내 미분처리를 수행해야 한다. 여기서 backward()라 함은, 각각 add(), multiply(), divide(), subtract()를 수행할 때 생성한 역전파 명령이다.
여기서는 executeBackward()로 정의하였다. 즉 메인 함수인 backward() 를 요청하면 해당 지점에서 이전 노드를 모두 탐색해서 각각 노드에 순차적으로 executeBackward() 를 수행해주는 것이다. 유의 해야할 것이 순서도 중요하다. 역전파를 수행할 때 맨 앞부터 backward()가 실행되어서는 안된다. 각각 한 노드에 대해서는 미분을 수행할 수 있지만 보통 시작지점에서 미분을 시작해 나가야하기 때문에, (즉 시작지점의 미분 값이 연쇄적으로 전달되어 곱해지기 때문에) 각각 노드의 미분 값 * NextNode 의 미분결괏값이 부모 노드의 미분 값에 반영된다. 즉 마지막 노드부터 순차적으로 미분이 되어가면서 첫 시작(순전파 시작)의 노드로 미분 값이 전달되어야 한다. 따라서 정렬 또한 역전파 순서대로 정렬되어야 한다.
실제 역전파 함수는 executeBackward()를 각각 노드에 실행하기 전에 해당 노드로부터 순전파 시작까지 노드를 모두 탐색한다. 우선 탐색할 때는 스택 순서대로 탐색을 진행한다.
위 노드에서 순서를 매겨보자. 1,2,3,4 노드가 각각 유기적으로 연결되어있고, 순서는 1->2->3 이고, 4->3 이다. 3에서 최종 값이 출력된다. 이를 후방 탐색을 기준으로 (즉 재귀함수로 탐색해도 된다. 다만 여기서는 while문으로 탐색하기 위해 위와 같이 구현하였다.) 정렬하면, 3->2->1->4 가 되고, 이를 reverse 하면 후방 탐색 결과가 도출된다. 즉 왼쪽을 먼저 파고 들되, 안쪽을 먼저 방문하는 것이다.
아무튼 노드가 깊이 순서대로 정렬되는 것을 확인할 수 있는데 여기서 정렬된 순서대로 executeBackward()를 하면된다. 다만 고려할 점은
위 계산에서 Hidden_Node1 이 중복적으로 계산 노드 여러곳에서 사용중일 때, 위 itemList 노드 순서대로라면, 3 역전파, 2역전파, 1역전파, 4역전파가 진행되는데 같은 노드를 사용하는게 매우 많을 경우 연산 속도 저하로 이어질 수 있다. 따라서 같은 연산끼리는 미분 값을 모아두다가, 최종 Hidden_Node1 탐색시에만 내부로 빠져들 수 있게끔 별도 작업을 해 주었는데, 그것이 바로 다음 부분이다.
즉 중복 처리 부분을 방지하는 부분이다. 중복 처리를 안해도 계산 결과는 동일하게 나오나... 이래도 되는지는 잘 모르겠다. 아무튼 중복 처리를 방지하였다. 중복을 방지하기 위해 같은 연산 노드는 무조건 역전파시 마지막 탐색의 노드에서만 해당 노드의 부모로 파고 들게끔 구현을 해야한다. 일부 노드를 잘못 중복 처리하여 제외해버리면 미분 값이 덜 더해질 수 있다.
이제 위 클래스를 사용해서 역전파를 수행해본다,
메인에서는 임의 식을 연산해본다. cmp 라는 변수(x)를 하나 만들어 3이라는 값을 넣어두고, x**2 + 4*x + 50x 를 연산한다.
alert(result.getValue());
result.getValue() 값 자체는 171이 나온다. 위 식의 연산 결과이다. 미분 값은 어떻게 나올까?
미분을 수행하기 위해 최종 result에 대해 backward를 수행하고, x 값 3.0 을 가지고 있는 cmp 변수에 대해 grid 값을 출력해본다. 60이 출력되는 것을 확인할 수 있다. 미분 식인 2x + 4 + 50 에 3을 넣었을 때 60이 나오는 것을 확인할 수 있다. 즉 어떤 식도 쉽게 미분할 수 있는 자동 미분 클래스를 만든 것이다.
댓글 0개
댓글은 일회용 패스워드가 발급되며 사이트 이용 약관에 동의로 간주됩니다.
확인
Whitmemit 개인 일지 블로그는 개인이 운영하는 정보 공유 공간으로 사용자의 민감한 개인 정보를 직접 요구하거나 요청하지 않습니다. 기본적인 사이트 방문시 처리되는 처리 정보에 대해서는 '사이트 처리 방침'을 참고하십시오. 추가적인 기능의 제공을 위하여 쿠키 정보를 사용하고 있습니다. Whitmemit 에서 처리하는 정보는 식별 용도로 사용되며 기타 글꼴 및 폰트 라이브러리에서 쿠키 정보를 사용할 수 있습니다.
이 자료는 모두 필수 자료로 간주되며, 사이트 이용을 하거나, 탐색하는 경우 동의로 간주합니다.