[Paper Review] NeRF Code Review

[Paper Review] NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis

Input Data

transforms_train.json

camera_angle_x : The FOV in x dimension
frames : List of dictionaries, one per image, each containing that image's camera transform matrix (a camera-to-world matrix)

삼각법을 이용한 초점 거리 계산

# Focal length derived from the horizontal field of view in the JSON metadata.
focalLength = get_focal_from_fov(
	fieldOfView=traindata["camera_angle_x"],
	width=config.IMAGE_WIDTH,
)

1
2
def get_focal_from_fov(fieldOfView, width):
	"""Return the focal length for a given field of view and image width.

	Computes (width / 2) / tan(fieldOfView / 2). The angle is handed to
	tf.tan, so it is interpreted in radians.
	"""
	halfWidth = 0.5 * width
	halfAngle = 0.5 * fieldOfView
	return halfWidth / tf.tan(halfAngle)

tf.tan 텐서에 있는 모든 요소의 탄젠트를 계산

그림2.png

이미지의 경로와 camera to world transform matrix 얻기

# Image file paths and the matching camera-to-world matrices for the train split.
trainImagePaths, trainC2Ws = get_image_c2w(
	jsonData=traindata,
	datasetPath=config.DATASET_PATH,
)

1
2
3
4
5
6
7
8
9
10
11
def get_image_c2w(jsonData, datasetPath):
	"""Collect image paths and camera-to-world matrices from frame metadata.

	Args:
		jsonData: Mapping with a "frames" list; every frame provides a
			"file_path" and a "transform_matrix" entry.
		datasetPath: Dataset root substituted for the first "." of each
			frame's "file_path" (the relative-path prefix).

	Returns:
		A tuple (imagePaths, c2ws) of equal-length lists: the image paths
		with a ".png" suffix appended, and the matching pose
		(camera-to-world) matrices exactly as stored in the frames.
	"""
	imagePaths = []
	c2ws = []  # pose matrices -> camera-to-world matrices

	for frame in jsonData["frames"]:
		# Replace only the first "." (the "./" prefix). Replacing every "."
		# would corrupt any path that contains another dot.
		imagePath = frame["file_path"].replace(".", datasetPath, 1)
		imagePaths.append(f"{imagePath}.png")
		c2ws.append(frame["transform_matrix"])
	return (imagePaths, c2ws)

디스크에서 이미지를 로드하는 데 사용되는 클래스의 개체를 인스턴스화

# Loader object that turns an image path into a decoded, resized tensor.
getImages = GetImages(
	imageHeight=config.IMAGE_HEIGHT,
	imageWidth=config.IMAGE_WIDTH,
)

# Dataset of training images, loaded in parallel from the path list.
trainImageDs = tf.data.Dataset.from_tensor_slices(trainImagePaths)
trainImageDs = trainImageDs.map(getImages, num_parallel_calls=config.AUTO)

1
2
3
4
5
6
7
8
9
10
11
12
class GetImages():
	"""Callable that loads an image file from disk as a float32 RGB tensor.

	Instances are meant to be passed to `tf.data.Dataset.map`.
	"""

	def __init__(self, imageWidth, imageHeight):
		"""Store the target image size.

		Args:
			imageWidth: Target width in pixels.
			imageHeight: Target height in pixels.
		"""
		self.imageWidth = imageWidth
		self.imageHeight = imageHeight

	def __call__(self, imagePath):
		"""Read, decode and resize one image.

		Args:
			imagePath: Path of the image file to load.

		Returns:
			A float32 tensor of shape (imageHeight, imageWidth, 3).
		"""
		image = tf.io.read_file(imagePath)
		image = tf.io.decode_jpeg(image, 3)
		image = tf.image.convert_image_dtype(image, dtype=tf.float32)
		# tf.image.resize takes size as (new_height, new_width); passing
		# (width, height) transposes the target size for non-square images.
		image = tf.image.resize(image, (self.imageHeight, self.imageWidth))
		# Pin the static shape; use tf.reshape like the other tf.* calls here.
		image = tf.reshape(image, (self.imageHeight, self.imageWidth, 3))
		return image

tf.io.decode_jpeg JPEG 인코딩 이미지를 uint8 텐서로 디코딩

tf.data.Dataset.from_tensor_slices tf.data.Dataset 를 생성하는 함수로 입력된 텐서로부터 slices를 생성

ray의 object 인스턴스화

# Ray generator configured with the camera intrinsics and sampling bounds.
getRays = GetRays(
	focalLength=focalLength,
	imageWidth=config.IMAGE_WIDTH,
	imageHeight=config.IMAGE_HEIGHT,
	NEAR_BOUNDS=config.NEAR_BOUNDS,
	FAR_BOUNDS=config.FAR_BOUNDS,
	nC=config.NUM_COARSE,
)

# Dataset of (origin, direction, sample positions) built from each pose matrix.
trainRayDs = tf.data.Dataset.from_tensor_slices(trainC2Ws)
trainRayDs = trainRayDs.map(getRays, num_parallel_calls=config.AUTO)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class GetRays:
	"""Callable that turns a camera-to-world matrix into rays and sample positions."""

	def __init__(self, focalLength, imageWidth, imageHeight, NEAR_BOUNDS,
		FAR_BOUNDS, nC):
		"""Store the camera parameters, the sampling range and the sample count."""
		# camera / image geometry
		self.focalLength = focalLength
		self.imageWidth = imageWidth
		self.imageHeight = imageHeight
		# range along the ray and number of coarse samples
		self.NEAR_BOUNDS = NEAR_BOUNDS
		self.FAR_BOUNDS = FAR_BOUNDS
		self.nC = nC

	def __call__(self, camera2world):
		"""Build one ray per pixel plus jittered sample positions along it.

		Returns a tuple (rayO, rayD, tVals): the ray origins, the
		unit-length ray directions, and the sample positions between the
		near and far bounds with uniform noise added.
		"""
		# Pixel grid: "xy" indexing puts the width axis last.
		cols = tf.range(self.imageWidth, dtype=tf.float32)
		rows = tf.range(self.imageHeight, dtype=tf.float32)
		(u, v) = tf.meshgrid(cols, rows, indexing="xy")

		# Centre the pixel coordinates and scale them by the focal length.
		xCam = (u - self.imageWidth * 0.5) / self.focalLength
		yCam = (v - self.imageHeight * 0.5) / self.focalLength

		# Direction of every pixel in camera space (y flipped, z fixed at -1).
		camDirs = tf.stack([xCam, -yCam, -tf.ones_like(u)], axis=-1)

		# Split the pose into its rotation block and translation column.
		rotation = camera2world[:3, :3]
		translation = camera2world[:3, -1]

		# Rotate into world space: broadcast-multiply against the rotation
		# rows, then sum over the last axis.
		worldDirs = tf.reduce_sum(camDirs[..., None, :] * rotation, axis=-1)
		rayD = worldDirs / tf.norm(worldDirs, axis=-1, keepdims=True)

		# All rays start at the camera position.
		rayO = tf.broadcast_to(translation, tf.shape(rayD))

		# Evenly spaced positions, then a random offset of up to one bin width.
		tVals = tf.linspace(self.NEAR_BOUNDS, self.FAR_BOUNDS, self.nC)
		noiseShape = list(rayO.shape[:-1]) + [self.nC]
		jitter = tf.random.uniform(shape=noiseShape)
		jitter = jitter * (self.FAR_BOUNDS - self.NEAR_BOUNDS) / self.nC
		tVals = tVals + jitter

		return (rayO, rayD, tVals)

https://keras.io/examples/vision/nerf/

  • tf.reduce_sum 지정한 축(차원)을 따라 요소들의 합을 계산하며, 해당 차원은 결과에서 제거됨

  • tf.norm 벡터, 행렬 및 텐서의 노름(norm)을 계산

  • tf.linspace 주어진 축을 따라 일정한 간격의 값을 생성합니다

이미지 및 ray 데이터 세트를 zip으로 함께 묶기

: build data input pipeline for train, val, and test datasets

# Pair every ray element with its image: (rays, image).
traindata = tf.data.Dataset.zip(
	(trainRayDs, trainImageDs),
)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Training pipeline: shuffle, batch, repeat indefinitely, prefetch.
traindata = traindata.shuffle(config.BATCH_SIZE)
traindata = traindata.batch(config.BATCH_SIZE)
traindata = traindata.repeat()
traindata = traindata.prefetch(config.AUTO)

# Validation pipeline: same stages as training.
valdata = valdata.shuffle(config.BATCH_SIZE)
valdata = valdata.batch(config.BATCH_SIZE)
valdata = valdata.repeat()
valdata = valdata.prefetch(config.AUTO)

# Test pipeline: batch and prefetch only (no shuffle, no repeat).
testdata = testdata.batch(config.BATCH_SIZE)
testdata = testdata.prefetch(config.AUTO)

Set Model

Instantiation of the Coarse Model & Fine Model

# Coarse network.
coarseModel = get_model(
	lxyz=config.DIMS_XYZ,
	lDir=config.DIMS_DIR,
	batchSize=config.BATCH_SIZE,
	denseUnits=config.UNITS,
	skipLayer=config.SKIP_LAYER,
)

# Fine network: built with the same configuration as the coarse one.
fineModel = get_model(
	lxyz=config.DIMS_XYZ,
	lDir=config.DIMS_DIR,
	batchSize=config.BATCH_SIZE,
	denseUnits=config.UNITS,
	skipLayer=config.SKIP_LAYER,
)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def get_model(lxyz, lDir, batchSize, denseUnits, skipLayer):
	"""Build the MLP that maps encoded positions and directions to (rgb, sigma).

	Args:
		lxyz: Number of encoding frequencies for positions; the position
			input has 2 * 3 * lxyz + 3 features.
		lDir: Number of encoding frequencies for directions; the direction
			input has 2 * 3 * lDir + 3 features.
		batchSize: Fixed batch size of both inputs.
		denseUnits: Width of the hidden Dense layers.
		skipLayer: The position input is concatenated back in after every
			hidden layer whose index is a non-zero multiple of this value.

	Returns:
		A Model taking [rayInput, dirInput] and producing [rgb, sigma].
	"""
	rayFeatures = 2 * 3 * lxyz + 3
	dirFeatures = 2 * 3 * lDir + 3
	rayInput = Input(shape=(None, None, None, rayFeatures), batch_size=batchSize)
	dirInput = Input(shape=(None, None, None, dirFeatures), batch_size=batchSize)

	# Trunk of eight ReLU layers with periodic skip connections to the input.
	hidden = rayInput
	for layerIdx in range(8):
		hidden = Dense(units=denseUnits, activation="relu")(hidden)
		if layerIdx % skipLayer == 0 and layerIdx > 0:
			hidden = concatenate([hidden, rayInput], axis=-1)

	# Density head: ReLU keeps sigma non-negative.
	sigma = Dense(units=1, activation="relu")(hidden)

	# Colour head: features joined with the direction input, then a
	# half-width ReLU layer and a sigmoid output.
	feature = Dense(units=denseUnits)(hidden)
	feature = concatenate([feature, dirInput], axis=-1)
	hidden = Dense(units=denseUnits//2, activation="relu")(feature)
	rgb = Dense(units=3, activation="sigmoid")(hidden)

	return Model(inputs=[rayInput, dirInput], outputs=[rgb, sigma])

그림2.png

NeRF trainer model

# Trainer wrapping both networks plus the encoding, rendering and sampling helpers.
nerfModel = Nerf_Trainer(
	coarseModel=coarseModel,
	fineModel=fineModel,
	lxyz=config.DIMS_XYZ,
	lDir=config.DIMS_DIR,
	encoderFn=encoder_fn,
	renderImageDepth=render_image_depth,
	samplePdf=sample_pdf,
	nF=config.NUM_FINE,
)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Nerf_Trainer(tf.keras.Model):
	"""Keras model that holds a coarse and a fine network and their helpers."""

	def __init__(self, coarseModel, fineModel, lxyz, lDir, encoderFn, renderImageDepth, samplePdf, nF):
		"""Store the two networks, the encoding sizes and the helper callables."""
		super().__init__()
		# The two networks (coarse is assigned first, then fine).
		self.coarseModel, self.fineModel = coarseModel, fineModel

		# Encoding sizes for positions and for directions.
		self.lxyz, self.lDir = lxyz, lDir

		# Encoding function.
		self.encoderFn = encoderFn

		# Volume rendering function.
		self.renderImageDepth = renderImageDepth

		# Hierarchical sampling function and the number of fine samples.
		self.samplePdf = samplePdf
		self.nF = nF

	def compile(self, optimizerCoarse, optimizerFine, lossFn):
		"""Attach one optimizer per network, the loss, and the metric trackers."""
		super().compile()
		self.optimizerCoarse, self.optimizerFine = optimizerCoarse, optimizerFine
		self.lossFn = lossFn

		# Running means reported as "loss" and "psnr".
		self.lossTracker = Mean(name="loss")
		self.psnrMetric = Mean(name="psnr")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def train_step(self, inputs):
		"""Run one training step for the coarse and the fine model.

		`inputs` is a pair (elements, images) where `elements` unpacks to
		(ray origins, ray directions, coarse sample positions). The coarse
		model is evaluated and rendered first; its rendering weights drive
		the hierarchical sampling that produces the fine sample positions.
		Each model is updated with its own optimizer from its own loss.

		Returns a dict with the running "loss" (fine loss only) and "psnr".
		"""
		# unpack the batch into ray elements and target images
		(elements, images) = inputs
		(raysOriCoarse, raysDirCoarse, tValsCoarse) = elements

		# sample points along every ray: origin + direction * t
		raysCoarse = (raysOriCoarse[..., None, :] + 
			(raysDirCoarse[..., None, :] * tValsCoarse[..., None]))

		# positionally encode the sample points; broadcast the directions to
		# one per sample point and encode them as well
		raysCoarse = self.encoderFn(raysCoarse, self.lxyz)
		dirCoarseShape = tf.shape(raysCoarse[..., :3])
		dirsCoarse = tf.broadcast_to(raysDirCoarse[..., None, :], shape=dirCoarseShape)
		dirsCoarse = self.encoderFn(dirsCoarse, self.lDir)

		# record the coarse forward pass for differentiation
		with tf.GradientTape() as coarseTape:
			# compute the predictions from the coarse model
			(rgbCoarse, sigmaCoarse) = self.coarseModel([raysCoarse, dirsCoarse])
			
			# render the image from the predictions
			renderCoarse = self.renderImageDepth(rgb=rgbCoarse,sigma=sigmaCoarse, tVals=tValsCoarse)
			(imagesCoarse, _, weightsCoarse) = renderCoarse

			# compute the photometric loss
			lossCoarse = self.lossFn(images, imagesCoarse)

		# compute the midpoints between consecutive t vals (one fewer entry
		# along the last axis than tValsCoarse)
		tValsCoarseMid = (0.5 * (tValsCoarse[..., 1:] + tValsCoarse[..., :-1]))

		# apply hierarchical sampling and get the t vals for the fine model;
		# the coarse and fine positions are merged and sorted
		# NOTE(review): if renderImageDepth is render_image_depth from this
		# file, weightsCoarse has one more entry along the last axis than
		# tValsCoarseMid -- confirm samplePdf expects that pairing.
		tValsFine = self.samplePdf(tValsMid=tValsCoarseMid, weights=weightsCoarse, nF=self.nF)
		tValsFine = tf.sort(tf.concat([tValsCoarse, tValsFine], axis=-1), axis=-1)

		# build the fine rays and positionally encode them
		raysFine = (raysOriCoarse[..., None, :] + (raysDirCoarse[..., None, :] * tValsFine[..., None]))
		raysFine = self.encoderFn(raysFine, self.lxyz)
		
		# build the fine directions and positionally encode them
		dirsFineShape = tf.shape(raysFine[..., :3])
		dirsFine = tf.broadcast_to(raysDirCoarse[..., None, :], shape=dirsFineShape)
		dirsFine = self.encoderFn(dirsFine, self.lDir)

		# record the fine forward pass for differentiation
		with tf.GradientTape() as fineTape:
			# compute the predictions from the fine model
			rgbFine, sigmaFine = self.fineModel([raysFine, dirsFine])
			
			# render the image from the predictions
			renderFine = self.renderImageDepth(rgb=rgbFine, sigma=sigmaFine, tVals=tValsFine)
			(imageFine, _, _) = renderFine

			# compute the photometric loss
			lossFine = self.lossFn(images, imageFine)

		# get the trainable variables from the coarse model and
		# apply back propagation
		tvCoarse = self.coarseModel.trainable_variables
		gradsCoarse = coarseTape.gradient(lossCoarse, tvCoarse)
		self.optimizerCoarse.apply_gradients(zip(gradsCoarse, 
			tvCoarse))

		# get the trainable variables from the fine model and
		# apply back propagation
		tvFine = self.fineModel.trainable_variables
		gradsFine = fineTape.gradient(lossFine, tvFine)
		self.optimizerFine.apply_gradients(zip(gradsFine, tvFine))
		psnr = tf.image.psnr(images, imageFine, max_val=1.0)

		# update the running loss (fine loss only) and psnr metrics
		self.lossTracker.update_state(lossFine)
		self.psnrMetric.update_state(psnr)

		# return the loss and psnr metrics
		return {"loss": self.lossTracker.result(),
			"psnr": self.psnrMetric.result()}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def test_step(self, inputs):
		"""Evaluate one batch with the coarse-then-fine pipeline.

		`inputs` is a pair (elements, images) where `elements` unpacks to
		(ray origins, ray directions, coarse sample positions). No gradients
		are recorded and no optimizer is applied; only the fine rendering is
		scored against the target images.

		Returns a dict with the running "loss" and "psnr".
		"""
		# get the images and the rays
		(elements, images) = inputs
		(raysOriCoarse, raysDirCoarse, tValsCoarse) = elements

		# generate the coarse sample points: origin + direction * t
		raysCoarse = (raysOriCoarse[..., None, :] + 
			(raysDirCoarse[..., None, :] * tValsCoarse[..., None]))

		# positionally encode the rays and dirs
		raysCoarse = self.encoderFn(raysCoarse, self.lxyz)
		dirCoarseShape = tf.shape(raysCoarse[..., :3])
		dirsCoarse = tf.broadcast_to(raysDirCoarse[..., None, :],
			shape=dirCoarseShape)
		dirsCoarse = self.encoderFn(dirsCoarse, self.lDir)

		# compute the predictions from the coarse model
		(rgbCoarse, sigmaCoarse) = self.coarseModel([raysCoarse,
			dirsCoarse])
		
		# render the image from the predictions; only the weights are kept
		renderCoarse = self.renderImageDepth(rgb=rgbCoarse,
			sigma=sigmaCoarse, tVals=tValsCoarse)
		(_, _, weightsCoarse) = renderCoarse

		# compute the midpoints between consecutive t vals
		tValsCoarseMid = (0.5 * 
			(tValsCoarse[..., 1:] + tValsCoarse[..., :-1]))

		# apply hierarchical sampling and get the t vals for the fine
		# model; the coarse and fine positions are merged and sorted
		tValsFine = self.samplePdf(tValsMid=tValsCoarseMid,
			weights=weightsCoarse, nF=self.nF)
		tValsFine = tf.sort(
			tf.concat([tValsCoarse, tValsFine], axis=-1), axis=-1)

		# build the fine rays and positionally encode them
		raysFine = (raysOriCoarse[..., None, :] + 
			(raysDirCoarse[..., None, :] * tValsFine[..., None]))
		raysFine = self.encoderFn(raysFine, self.lxyz)
		
		# build the fine directions and positionally encode them
		dirsFineShape = tf.shape(raysFine[..., :3])
		dirsFine = tf.broadcast_to(raysDirCoarse[..., None, :],
			shape=dirsFineShape)
		dirsFine = self.encoderFn(dirsFine, self.lDir)

		# compute the predictions from the fine model
		rgbFine, sigmaFine = self.fineModel([raysFine, dirsFine])
		
		# render the image from the predictions
		renderFine = self.renderImageDepth(rgb=rgbFine,
			sigma=sigmaFine, tVals=tValsFine)
		(imageFine, _, _) = renderFine

		# compute the photometric loss and psnr
		lossFine = self.lossFn(images, imageFine)
		psnr = tf.image.psnr(images, imageFine, max_val=1.0)

		# update the running loss and psnr metrics
		self.lossTracker.update_state(lossFine)
		self.psnrMetric.update_state(psnr)

		# return the loss and psnr metrics
		return {"loss": self.lossTracker.result(),
			"psnr": self.psnrMetric.result()}

	@property
	def metrics(self):
		"""Expose the loss tracker and the PSNR tracker, in that order."""
		trackers = [self.lossTracker, self.psnrMetric]
		return trackers
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def render_image_depth(rgb, sigma, tVals):
	"""Volume-render predicted colours and densities into an image and depth map.

	Args:
		rgb: Predicted colour for every sample point (last axis = channels).
		sigma: Predicted density for every sample point, with a trailing
			axis of size 1 that is dropped here.
		tVals: Sample positions along each ray (last axis = samples).

	Returns:
		A tuple (image, depth, weights): the colours and the sample
		positions, each summed over the sample axis with the rendering
		weights, and the per-sample weights themselves.
	"""
	sigma = sigma[..., 0]

	# Distance between consecutive samples. The last sample gets an
	# effectively infinite interval. Its shape is taken from tVals rather
	# than from global batch/image constants, so any batch size or image
	# size works; ones_like also keeps the static shape information.
	delta = tVals[..., 1:] - tVals[..., :-1]
	lastDelta = 1e10 * tf.ones_like(tVals[..., :1])
	delta = tf.concat([delta, lastDelta], axis=-1)

	# Opacity of every interval, and the transmittance accumulated before it.
	alpha = 1.0 - tf.exp(-sigma * delta)
	expTerm = 1.0 - alpha
	epsilon = 1e-10  # keeps the cumulative product away from exact zero
	transmittance = tf.math.cumprod(expTerm + epsilon, axis=-1, exclusive=True)
	weights = alpha * transmittance

	image = tf.reduce_sum(weights[..., None] * rgb, axis=-2)
	depth = tf.reduce_sum(weights * tVals, axis=-1)
	return (image, depth, weights)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def encoder_fn(p, L):
    """Positionally encode `p` with `L` sine/cosine frequency bands.

    The result concatenates, along the last axis, `p` itself followed by
    sin(2**i * p) and cos(2**i * p) for i = 0 .. L - 1, so the last axis
    grows by a factor of 2 * L + 1.
    """
    bands = [p]
    for power in range(L):
        scaled = (2.0 ** power) * p
        bands.extend([tf.sin(scaled), tf.cos(scaled)])
    return tf.concat(bands, axis=-1)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def sample_pdf(tValsMid, weights, nF):
	"""Draw `nF` sample positions per ray by inverse-transform sampling.

	Args:
		tValsMid: Bin positions along each ray (last axis = bins).
		weights: Non-negative weight per bin; normalised here into a pdf.
		nF: Number of samples to draw per ray.

	Returns:
		Sampled positions with the leading axes of `weights` and a last
		axis of size `nF`.

	NOTE(review): the gather indices range up to the last index of the cdf,
	which has one more entry than `weights`. This presumably requires
	`tValsMid` to be at least that long along the last axis -- confirm
	against the callers.
	"""
	# Small offset so the normalisation never divides by zero. Not done in
	# place, so a caller's array is never modified.
	weights = weights + 1e-5
	pdf = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)
	cdf = tf.cumsum(pdf, axis=-1)
	cdf = tf.concat([tf.zeros_like(cdf[..., :1]), cdf], axis=-1)

	# One uniform draw per requested sample. The leading axes follow
	# `weights` instead of global batch/image constants, so any batch size
	# or image size works.
	uShape = tf.concat([tf.shape(weights)[:-1], [nF]], axis=0)
	u = tf.random.uniform(shape=uShape)

	# Locate the cdf interval containing each draw and clamp to valid indices.
	indices = tf.searchsorted(cdf, u, side="right")
	below = tf.maximum(0, indices-1)
	above = tf.minimum(tf.shape(cdf)[-1]-1, indices)
	indicesG = tf.stack([below, above], axis=-1)

	cdfG = tf.gather(cdf, indicesG, axis=-1, batch_dims=len(indicesG.shape)-2)
	tValsMidG = tf.gather(tValsMid, indicesG, axis=-1, batch_dims=len(indicesG.shape)-2)

	# Linear interpolation inside the interval; near-empty intervals use a
	# denominator of 1 to avoid dividing by ~0.
	denom = cdfG[..., 1] - cdfG[..., 0]
	denom = tf.where(denom < 1e-5, tf.ones_like(denom), denom)
	t = (u - cdfG[..., 0]) / denom
	samples = (tValsMidG[..., 0] + t * (tValsMidG[..., 1] - tValsMidG[..., 0]))

	return samples

Train Model

Compiling the Model (optimizer used : Adam, Loss Function : Mean Squared Error)

# One Adam optimizer per network; mean squared error as the photometric loss.
nerfModel.compile(
	optimizerCoarse=Adam(),
	optimizerFine=Adam(),
	lossFn=MeanSquaredError(),
)

모델 학습 및 평가 결과를 추적하기 위한 학습 모니터 콜백(train monitor callback)을 생성

# Callback that renders a test batch after every epoch and saves the plot.
trainMonitorCallback = get_train_monitor(
	testDs=testdata,
	encoderFn=encoder_fn,
	lxyz=config.DIMS_XYZ,
	lDir=config.DIMS_DIR,
	imagePath=config.IMAGE_PATH,
)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def get_train_monitor(testDs, encoderFn, lxyz, lDir, imagePath):
	"""Create a callback that renders one test batch after every epoch.

	Args:
		testDs: Dataset yielding ((ray origins, ray directions, sample
			positions), images); only its first element is used.
		encoderFn: Positional-encoding function applied to the coarse rays
			and directions.
		lxyz: Encoding size for positions.
		lDir: Encoding size for directions.
		imagePath: Directory where the per-epoch comparison plots are
			written as "<epoch>.png". The callback body and the call site
			both use it, so it must be a parameter of this function.

	Returns:
		A Callback whose `on_epoch_end` plots the coarse render, the fine
		render, the fine depth map and the real image of the first sample.
	"""
	# Take a single fixed batch so every epoch is visualised on the same data.
	(tElements, tImages) = next(iter(testDs))
	(tRaysOriCoarse, tRaysDirCoarse, tTvalsCoarse) = tElements

	tRaysCoarse = (tRaysOriCoarse[..., None, :] + (tRaysDirCoarse[..., None, :] * tTvalsCoarse[..., None]))

	# The coarse inputs do not change between epochs: encode them once here.
	tRaysCoarse = encoderFn(tRaysCoarse, lxyz)
	tDirsCoarseShape = tf.shape(tRaysCoarse[..., :3])
	tDirsCoarse = tf.broadcast_to(tRaysDirCoarse[..., None, :], shape=tDirsCoarseShape)
	tDirsCoarse = encoderFn(tDirsCoarse, lDir)

	class TrainMonitor(Callback):
		def on_epoch_end(self, epoch, logs=None):
			"""Render the fixed test batch with the current weights and save a plot."""
			(tRgbCoarse, tSigmaCoarse) = self.model.coarseModel.predict([tRaysCoarse, tDirsCoarse])

			tRenderCoarse = self.model.renderImageDepth(rgb=tRgbCoarse, sigma=tSigmaCoarse, tVals=tTvalsCoarse)
			(tImageCoarse, _, tWeightsCoarse) = tRenderCoarse

			# mid values of the t vals
			tTvalsCoarseMid = (0.5 * (tTvalsCoarse[..., 1:] + tTvalsCoarse[..., :-1]))

			# apply hierarchical sampling and get the t vals for the fine model
			tTvalsFine = self.model.samplePdf(tValsMid=tTvalsCoarseMid, weights=tWeightsCoarse, nF=self.model.nF)
			tTvalsFine = tf.sort(tf.concat([tTvalsCoarse, tTvalsFine], axis=-1), axis=-1)

			# build the fine rays and encode them
			tRaysFine = (tRaysOriCoarse[..., None, :] + (tRaysDirCoarse[..., None, :] * tTvalsFine[..., None]))
			tRaysFine = self.model.encoderFn(tRaysFine, lxyz)

			# build the fine directions and encode them
			tDirsFineShape = tf.shape(tRaysFine[..., :3])
			tDirsFine = tf.broadcast_to(tRaysDirCoarse[..., None, :], shape=tDirsFineShape)
			tDirsFine = self.model.encoderFn(tDirsFine, lDir)

			# compute the fine model prediction
			tRgbFine, tSigmaFine = self.model.fineModel.predict([tRaysFine, tDirsFine])

			# render the image from the model prediction
			tRenderFine = self.model.renderImageDepth(rgb=tRgbFine, sigma=tSigmaFine, tVals=tTvalsFine)
			(tImageFine, tDepthFine, _) = tRenderFine

			# plot the coarse image, fine image, fine depth map and target image
			(_, ax) = plt.subplots(nrows=1, ncols=4, figsize=(10, 10))
			ax[0].imshow(array_to_img(tImageCoarse[0]))
			ax[0].set_title("Coarse Image")

			ax[1].imshow(array_to_img(tImageFine[0]))
			ax[1].set_title("Fine Image")

			ax[2].imshow(array_to_img(tDepthFine[0, ..., None]),
				cmap="inferno")
			ax[2].set_title("Fine Depth Image")

			ax[3].imshow(array_to_img(tImages[0]))
			ax[3].set_title("Real Image")

			# one file per epoch, zero-padded so the files sort in order
			plt.savefig(f"{imagePath}/{epoch:03d}.png")
			plt.close()

			tf.keras.backend.clear_session()

	trainMonitor = TrainMonitor()
	return trainMonitor

NERF Model training

# Train with a fixed number of steps per epoch; the monitor callback runs after each epoch.
nerfModel.fit(
	traindata,
	steps_per_epoch=config.STEPS_PER_EPOCH,
	validation_data=valdata,
	validation_steps=config.VALIDATION_STEPS,
	epochs=config.EPOCHS,
	callbacks=[trainMonitorCallback],
)

Saving the model parameters

# Persist both networks; each statement must be on its own line.
nerfModel.coarseModel.save(config.COARSE_PATH)
nerfModel.fineModel.save(config.FINE_PATH)

참고

  • https://keras.io/examples/vision/nerf/
//