pbj0812의 코딩 일기

[TensorFlow] tensorflow.js 로 학습한 모델 저장하기 본문

인공지능 & 머신러닝/TensorFlow

[TensorFlow] tensorflow.js 로 학습한 모델 저장하기

pbj0812 2021. 1. 7. 04:29

0. 목표

 - tensorflow.js 로 학습한 모델 저장하기

1. main.js

var http = require('http');
var fs = require('fs');
var app = http.createServer(function(request,response){
    var url = request.url;
    if(request.url == '/'){
      url = '/save.html';
    }
    if(request.url == '/favicon.ico'){
      return response.writeHead(404);
    }
    response.writeHead(200);
    response.end(fs.readFileSync(__dirname + url));
 
});
app.listen(3001);

2. save.html

<!DOCTYPE html>
<html>
 
<head>
    <title>TensorFlow.js Tutorial - lemon</title>
 
    <!-- Import TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
     
</head>
 
<body>
    <script>
        // 1. 데이터 준비합니다. 
        var xx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        var yy = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        var reason = tf.tensor(xx);
        var result = tf.tensor(yy);
 
        // 2. 모델의 모양을 만듭니다. 
        var X = tf.input({ shape: [1] });
        var Y = tf.layers.dense({ units: 1 }).apply(X);
        var model = tf.model({ inputs: X, outputs: Y });
        var compileParam = { optimizer: tf.train.adam(), loss: tf.losses.meanSquaredError }
        model.compile(compileParam);
 
        // 3. 데이터로 모델을 학습시킵니다. 
        console.log('model run');
        var fitParam = { 
            epochs: 100,
            callbacks: {
                onEpochEnd:
                    function(epoch, logs){
                        console.log('epoch', epoch, logs, 'RMSE : ', Math.sqrt(logs.loss));
                    }
            }
        } 
        
        console.log('model save');
        model.fit(reason, result, fitParam).then(function (result) {
            model.save('localstorage://my_model');
        });
    </script>
</body>
 
</html>

3. 실행

node main.js

4. 결과

 - localhost:3001

 - 검사

 - 저장된 모델은 상단 Application -> 좌측 Storage -> Local Storage

5. load.html

<!DOCTYPE html>
<html>
 
<head>
    <title>TensorFlow.js Tutorial - lemon</title>
 
    <!-- Import TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
     
</head>
 
<body>
    <script>
        console.log('model load');
        tf.loadLayersModel('localstorage://my_model').then(function(model){
            var weights = model.getWeights();
            var weight = weights[0].arraySync()[0][0];
            var bias = weights[1].arraySync()[0];
            console.log(weight);
            console.log(bias);
        });
    </script>
</body>
 
</html>

6. 실행

 - main.js 에서 save.html -> load.html 로 변경

node main.js

7. 결과

 - 페이지를 끄게 되면 데이터가 날아갈 수 있기에 새로고침

 - localhost:3001

8. 참고

 - 모델의 저장과 불러오기

 - 모델 저장 및로드

Comments