Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
31 |
Tags
- Computer Architecture
- Dreamhack
- Widget
- Image Processing
- 영상처리
- Algorithm
- C++
- 백준
- study book
- pytorch
- bloc
- llm을 활용 단어장 앱 개발일지
- MDP
- BOF
- Flutter
- Got
- rao
- MATLAB
- BAEKJOON
- ARM
- Stream
- system hacking
- DART
- Kaggle
- BFS
- ML
- fastapi를 사용한 파이썬 웹 개발
- PCA
- 파이토치 트랜스포머를 활용한 자연어 처리와 컴퓨터비전 심층학습
- FastAPI
Archives
- Today
- Total
Bull
[Deep Learning] 코드를 보며 RNN 이해하기 본문
RNN(Recurrent Neural Network)은 시퀀스 데이터를 처리하기 위해 설계된 인공 신경망의 한 유형이다.
RNN은 텍스트, 시간 시계열 데이터, 오디오 신호 등 순차적인 데이터에 유용하다.
처음 사진만 보고 이해가 안갈 수도 있다. 나도 혼자 책으로 처음 접할 때 output이 여러 개인 것에 대해 이해를 못했다.
하지만 최근 학교 수업에서 간략하게 코드를 통해 수업을 했는데 예제를 보고 어떤 느낌인 지 확 와닿았다.
이렇게 이해하면 쉽다. [0,1,2 → 3], [1,2,3 → 4], [2,3,4 → 5] ... [1111,1112,1113 → 1114]를 학습시킨 후,
[53783,53784,53785 → ???] 에 대한 y값을 예측하는 것이다.
위의 예제에서 입력값은 3개지만, 더 늘릴 수도 있다.
연속된 패턴을 기억할 수 있기 때문에 위의 지식을 기반을 자연어에 활용이 가능하다.
자연어처럼 문장을 숫자로 벡터라이징한 후 연속된 문장을 학습시킨다면, 빈칸 뚫린 문장을 넣어 빈칸에 들어가는 단어를 예측할 수 있다.
하지만 SimpleRNN은 단순하기 때문에 여기서 더 확장된 버전인 LSTM, GRU를 통해 자연어 처리를 구사할 수 있다.
Code
Getting Imports¶
In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, Dense, LSTM, GRU
Example sequence¶
In [3]:
sequence = np.linspace(0,1000,1001)
sequence
Out[3]:
array([ 0., 1., 2., ..., 998., 999., 1000.])
Prepare the dataset¶
In [4]:
def create_dataset(sequence, n_steps):
X, y = [], []
for i in range(len(sequence) - n_steps):
X.append(sequence[i:i + n_steps])
y.append(sequence[i + n_steps])
return np.array(X), np.array(y)
Get the Data¶
In [5]:
n_steps = 3
X, y = create_dataset(sequence, n_steps)
In [6]:
X
Out[6]:
array([[ 0., 1., 2.], [ 1., 2., 3.], [ 2., 3., 4.], ..., [995., 996., 997.], [996., 997., 998.], [997., 998., 999.]])
In [14]:
y
Out[14]:
array([ 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109., 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120., 121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131., 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142., 143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153., 154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164., 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175., 176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186., 187., 188., 189., 190., 191., 192., 193., 194., 195., 196., 197., 198., 199., 200., 201., 202., 203., 204., 205., 206., 207., 208., 209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., 220., 221., 222., 223., 224., 225., 226., 227., 228., 229., 230., 231., 232., 233., 234., 235., 236., 237., 238., 239., 240., 241., 242., 243., 244., 245., 246., 247., 248., 249., 250., 251., 252., 253., 254., 255., 256., 257., 258., 259., 260., 261., 262., 263., 264., 265., 266., 267., 268., 269., 270., 271., 272., 273., 274., 275., 276., 277., 278., 279., 280., 281., 282., 283., 284., 285., 286., 287., 288., 289., 290., 291., 292., 293., 294., 295., 296., 297., 298., 299., 300., 301., 302., 303., 304., 305., 306., 307., 308., 309., 310., 311., 312., 313., 314., 315., 316., 317., 318., 319., 320., 321., 322., 323., 324., 325., 326., 327., 328., 329., 330., 331., 332., 333., 334., 335., 336., 337., 338., 339., 340., 341., 342., 343., 344., 345., 346., 347., 348., 349., 350., 351., 352., 353., 354., 355., 356., 357., 358., 359., 360., 361., 362., 363., 364., 365., 366., 367., 368., 369., 370., 371., 372., 373., 374., 375., 376., 377., 378., 379., 380., 381., 382., 383., 384., 385., 386., 387., 388., 389., 390., 391., 392., 393., 394., 395., 396., 397., 398., 399., 400., 401., 402., 403., 404., 405., 406., 407., 408., 409., 410., 411., 412., 413., 414., 415., 416., 417., 418., 419., 420., 421., 422., 423., 424., 425., 426., 427., 428., 429., 430., 431., 432., 433., 434., 435., 436., 437., 438., 439., 440., 441., 442., 443., 444., 445., 446., 447., 448., 449., 450., 451., 452., 453., 454., 455., 456., 457., 458., 459., 460., 461., 462., 463., 464., 465., 466., 467., 468., 469., 470., 471., 472., 473., 474., 475., 476., 477., 478., 479., 480., 481., 482., 483., 484., 485., 486., 487., 488., 489., 490., 491., 492., 493., 494., 495., 496., 497., 498., 499., 500., 501., 502., 503., 504., 505., 506., 507., 508., 509., 510., 511., 512., 513., 514., 515., 516., 517., 518., 519., 520., 521., 522., 523., 524., 525., 526., 527., 528., 529., 530., 531., 532., 533., 534., 535., 536., 537., 538., 539., 540., 541., 542., 543., 544., 545., 546., 547., 548., 549., 550., 551., 552., 553., 554., 555., 556., 557., 558., 559., 560., 561., 562., 563., 564., 565., 566., 567., 568., 569., 570., 571., 572., 573., 574., 575., 576., 577., 578., 579., 580., 581., 582., 583., 584., 585., 586., 587., 588., 589., 590., 591., 592., 593., 594., 595., 596., 597., 598., 599., 600., 601., 602., 603., 604., 605., 606., 607., 608., 609., 610., 611., 612., 613., 614., 615., 616., 617., 618., 619., 620., 621., 622., 623., 624., 625., 626., 627., 628., 629., 630., 631., 632., 633., 634., 635., 636., 637., 638., 639., 640., 641., 642., 643., 644., 645., 646., 647., 648., 649., 650., 651., 652., 653., 654., 655., 656., 657., 658., 659., 660., 661., 662., 663., 664., 665., 666., 667., 668., 669., 670., 671., 672., 673., 674., 675., 676., 677., 678., 679., 680., 681., 682., 683., 684., 685., 686., 687., 688., 689., 690., 691., 692., 693., 694., 695., 696., 697., 698., 699., 700., 701., 702., 703., 704., 705., 706., 707., 708., 709., 710., 711., 712., 713., 714., 715., 716., 717., 718., 719., 720., 721., 722., 723., 724., 725., 726., 727., 728., 729., 730., 731., 732., 733., 734., 735., 736., 737., 738., 739., 740., 741., 742., 743., 744., 745., 746., 747., 748., 749., 750., 751., 752., 753., 754., 755., 756., 757., 758., 759., 760., 761., 762., 763., 764., 765., 766., 767., 768., 769., 770., 771., 772., 773., 774., 775., 776., 777., 778., 779., 780., 781., 782., 783., 784., 785., 786., 787., 788., 789., 790., 791., 792., 793., 794., 795., 796., 797., 798., 799., 800., 801., 802., 803., 804., 805., 806., 807., 808., 809., 810., 811., 812., 813., 814., 815., 816., 817., 818., 819., 820., 821., 822., 823., 824., 825., 826., 827., 828., 829., 830., 831., 832., 833., 834., 835., 836., 837., 838., 839., 840., 841., 842., 843., 844., 845., 846., 847., 848., 849., 850., 851., 852., 853., 854., 855., 856., 857., 858., 859., 860., 861., 862., 863., 864., 865., 866., 867., 868., 869., 870., 871., 872., 873., 874., 875., 876., 877., 878., 879., 880., 881., 882., 883., 884., 885., 886., 887., 888., 889., 890., 891., 892., 893., 894., 895., 896., 897., 898., 899., 900., 901., 902., 903., 904., 905., 906., 907., 908., 909., 910., 911., 912., 913., 914., 915., 916., 917., 918., 919., 920., 921., 922., 923., 924., 925., 926., 927., 928., 929., 930., 931., 932., 933., 934., 935., 936., 937., 938., 939., 940., 941., 942., 943., 944., 945., 946., 947., 948., 949., 950., 951., 952., 953., 954., 955., 956., 957., 958., 959., 960., 961., 962., 963., 964., 965., 966., 967., 968., 969., 970., 971., 972., 973., 974., 975., 976., 977., 978., 979., 980., 981., 982., 983., 984., 985., 986., 987., 988., 989., 990., 991., 992., 993., 994., 995., 996., 997., 998., 999., 1000.])
Reshape from [samples, timesteps] to [samples, timesteps, features]¶
In [19]:
X = X.reshape((X.shape[0], X.shape[1], 1))
X
Out[19]:
array([[[ 0.], [ 1.], [ 2.]], [[ 1.], [ 2.], [ 3.]], [[ 2.], [ 3.], [ 4.]], ..., [[995.], [996.], [997.]], [[996.], [997.], [998.]], [[997.], [998.], [999.]]])
Build the RNN model¶
In [35]:
model = Sequential()
model.add(SimpleRNN(50, activation='relu', input_shape=(n_steps, 1)))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')
c:\Users\Bu11\anaconda3\Lib\site-packages\keras\src\layers\rnn\rnn.py:204: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead. super().__init__(**kwargs)
Train the model¶
In [36]:
model.fit(X, y, epochs=200, verbose=0)
Out[36]:
<keras.src.callbacks.history.History at 0x2387b5d7810>
Demonstrate prediction¶
In [38]:
x_input = np.array([15, 16, 17])
x_input = x_input.reshape((1, n_steps, 1))
yhat = model.predict(x_input, verbose=0)
print(yhat)
[[18.032255]]
'Artificial Intelligence > Deep Learning' 카테고리의 다른 글
[DL] RNN (Recurrent Neural Network) | study books (1) | 2024.09.28 |
---|---|
[DL] MobileNet 요약 (0) | 2024.08.05 |
[DL/유머] 인공지능이 SoftMax(소맥)먹고 취하는 사진 (1) | 2024.03.17 |
[DL/CNN] ResNet (Residual Network) (0) | 2024.03.16 |
[DL] 순전파의 bias와 MDP의 reward의 차이 (0) | 2024.03.06 |