保存文件
将会话保存为文件:
python
saver = tf.train.Saver()
saver.save(sess, save_path=path)
加载会话:
python
saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, save_path=path)
使用 TensorBoard
保存会话图形信息:
python
writer = tf.summary.FileWriter('./graph', sess.graph)
启动 TensorBoard:
bash
tensorboard --logdir=${logdir}
为变量定义名称以便在 TensorBoard 中展示:
python
x = tf.placeholder(shape=(1, 3), dtype=tf.float32, name='x')
预测结果
python
res = sess.run([x, y], feed_dict={x: predict_data})
保存模型以供别的语言使用
python
builder = tf.saved_model.builder.SavedModelBuilder('./export')
builder.add_meta_graph_and_variables(sess, ['tag'])
builder.save()
Java 调用 TensorFlow 模型
在 Windows 下确保可以调用 libtensorflow-*.jar 和 tensorflow_jni.dll:
java
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.SavedModelBundle;
import java.nio.FloatBuffer;
import java.util.Arrays;
/**
 * Loads the SavedModel stored in the {@code export} directory under the tag
 * {@code tag}, feeds a 1x4 float matrix to the operation named {@code x},
 * fetches the operation named {@code y}, and prints the fetched tensor
 * followed by its first two float values.
 */
public class TensorFlowTest {
    public static void main(String[] args) {
        // SavedModelBundle and Tensor wrap native resources; try-with-resources
        // releases them even if feeding or running the graph throws.
        try (SavedModelBundle savedModelBundle = SavedModelBundle.load("export", "tag")) {
            Session sess = savedModelBundle.session();

            float[][] matrix = {{1.0f, 2.0f, 3.0f, 4.0f}};
            System.out.println(Arrays.deepToString(matrix));

            try (Tensor xFeed = Tensor.create(matrix);
                 Tensor result = sess.runner().feed("x", xFeed).fetch("y").run().get(0)) {
                // NOTE(review): capacity 2 assumes "y" holds exactly two floats;
                // confirm against the exported model.
                FloatBuffer buffer = FloatBuffer.allocate(2);
                result.writeTo(buffer);
                System.out.println(result.toString());
                System.out.println(buffer.get(0));
                System.out.println(buffer.get(1));
            }
        }
    }
}
Go 版本
需要安装 Go 版本的 TensorFlow,并输入下面命令安装
bash
go get github.com/tensorflow/tensorflow/tensorflow/go
go
// Command main loads the SavedModel exported to ./export and runs one
// prediction through it using tfgo.
package main

import (
	"fmt"
	"log"

	tg "github.com/galeone/tfgo"
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

// main feeds a single 1x4 float32 row to the operation named "x", fetches the
// operation named "y", and prints both the input and the prediction.
func main() {
	// NOTE(review): tfgo's LoadModel and Exec appear to panic on failure rather
	// than return an error — confirm against the tfgo version in use.
	model := tg.LoadModel("export", []string{"tag"}, nil)

	inputArray := [][]float32{{1, 2, 3, 4}}
	fmt.Printf("input: %v\n", inputArray)

	fakeInput, err := tf.NewTensor(inputArray)
	if err != nil {
		log.Fatalf("creating input tensor: %v", err)
	}

	result := model.Exec([]tf.Output{
		model.Op("y", 0),
	}, map[tf.Output]*tf.Tensor{
		model.Op("x", 0): fakeInput,
	})

	// Print the value as returned instead of asserting []float32: the
	// assertion panics when "y" is not rank 1, while %v formats a []float32
	// exactly as before.
	predict := result[0].Value()
	fmt.Printf("predict: %v\n", predict)
}