Skip to content

保存文件

将会话保存为文件:

python
saver = tf.train.Saver()
# Saver.save() takes the checkpoint prefix as `save_path`; it has no `path`
# keyword, so `saver.save(sess, path=path)` raises TypeError.
saver.save(sess, save_path=path)

加载会话:

python
# tf.train.Saver() restores the variables that already exist in the default
# graph, so the model graph must be rebuilt before this point.
saver = tf.train.Saver()
sess = tf.Session()
# The checkpoint prefix is passed as `save_path` (Saver.restore has no `path` keyword).
saver.restore(sess, save_path=path)

使用 TensorBoard

保存会话的计算图信息:

python
# Write the session's graph to ./graph so TensorBoard can display it.
writer = tf.summary.FileWriter('./graph', sess.graph)
# FileWriter writes events asynchronously; close() flushes them to disk so the
# graph is not lost if the process exits right after this line.
writer.close()

启动 TensorBoard:

bash
# logdir is the directory passed to tf.summary.FileWriter (e.g. ./graph).
tensorboard --logdir="${logdir}"

为变量定义名称以便在 TensorBoard 中展示:

python
# Naming the placeholder 'x' makes it appear under that name in TensorBoard and
# lets other languages feed it by name. Assumes 4 input features, matching the
# 1x4 matrices fed to 'x' in the Java and Go examples (adjust to your model);
# None leaves the batch dimension open.
x = tf.placeholder(shape=(None, 4), dtype=tf.float32, name='x')

预测结果

python
# Feed predict_data into x and fetch both tensors; res holds [x_value, y_value].
fetch_list = [x, y]
feeds = {x: predict_data}
res = sess.run(fetch_list, feed_dict=feeds)

保存模型以供其他语言调用

python
# Export the graph plus its variable values as a SavedModel under ./export.
# Loaders in other languages select this meta graph by passing the same tag, 'tag'.
# NOTE: SavedModelBuilder expects the export directory not to hold a model yet.
export_dir = './export'
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(sess, tags=['tag'])
builder.save()

Java 调用 TensorFlow 模型

在 Windows 下需确保程序能够加载 libtensorflow-*.jar 和 tensorflow_jni.dll(例如将 jar 加入 classpath,将 dll 所在目录加入 java.library.path)

java
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.SavedModelBundle;
import java.nio.FloatBuffer;
import java.util.Arrays;

public class TensorFlowTest {
    /**
     * Loads the SavedModel exported under ./export with tag "tag", feeds a 1x4
     * float matrix into the graph node "x" and prints the values fetched from "y".
     *
     * SavedModelBundle and Tensor hold native TensorFlow resources, so they are
     * released deterministically with try-with-resources.
     */
    public static void main(String[] args) {
        try (SavedModelBundle savedModelBundle = SavedModelBundle.load("export", "tag")) {
            // The session belongs to the bundle and is closed together with it.
            Session sess = savedModelBundle.session();

            float[][] matrix = {{1.0f, 2.0f, 3.0f, 4.0f}};
            System.out.println(Arrays.deepToString(matrix));

            try (Tensor xFeed = Tensor.create(matrix);
                 Tensor result = sess.runner().feed("x", xFeed).fetch("y").run().get(0)) {
                // Size the buffer from the result rather than hard-coding the element
                // count; writeTo overflows if the buffer is too small.
                FloatBuffer buffer = FloatBuffer.allocate(result.numElements());
                result.writeTo(buffer);

                System.out.println(result.toString());
                for (int i = 0; i < buffer.capacity(); i++) {
                    System.out.println(buffer.get(i));
                }
            }
        }
    }
}

Go 版本

需要先安装 TensorFlow C 库(libtensorflow),再输入下面的命令安装 Go 版本的 TensorFlow:

bash
# The Go bindings link against the TensorFlow C library, which must be installed first.
go get github.com/tensorflow/tensorflow/tensorflow/go
# tfgo wraps the bindings and is imported by the Go example below.
go get github.com/galeone/tfgo
go
package main

import (
	"fmt"
	"log"

	tg "github.com/galeone/tfgo"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

// main loads the SavedModel exported under ./export with tag "tag", feeds a
// 1x4 float32 matrix into the graph node "x" and prints the value of node "y".
func main() {
	model := tg.LoadModel("export", []string{"tag"}, nil)

	inputArray := [][]float32{{1, 2, 3, 4}}
	fmt.Printf("input: %v\n", inputArray)

	fakeInput, err := tf.NewTensor(inputArray)
	if err != nil {
		log.Fatalf("creating input tensor: %v", err)
	}

	result := model.Exec([]tf.Output{
		model.Op("y", 0),
	}, map[tf.Output]*tf.Tensor{
		model.Op("x", 0): fakeInput,
	})

	// Value() returns a Go value whose nesting follows the rank of "y"
	// ([]float32 for rank 1, [][]float32 for a batched rank-2 output, ...).
	// Printing it directly avoids a type assertion that panics on a rank mismatch.
	predict := result[0].Value()
	fmt.Printf("predict: %v\n", predict)
}